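"""Generate hypothetical fact-checking passages for a set of claims with vLLM.

For every claim in the input JSON file (a list of objects that each contain at
least a 'claim' field), the script samples several candidate passages, ranks
them by total log-probability, attaches them to the example under the
'hypo_fc_docs' key, and writes the augmented examples to the output JSON file.
"""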
from vllm import LLM, SamplingParams
import json
import torch
import time
from datetime import datetime, timedelta
import argparse
from tqdm import tqdm
from typing import List, Dict, Any
class VLLMGenerator:
    def __init__(self, model_name: str, n: int = 8, max_tokens: int = 512,
                 temperature: float = 0.7, top_p: float = 1.0,
                 frequency_penalty: float = 0.0, presence_penalty: float = 0.0,
                 stop: List[str] = ['\n\n\n'], batch_size: int = 32):
        self.device_count = torch.cuda.device_count()
        print(f"Initializing with {self.device_count} GPUs")
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=self.device_count,
            max_model_len=4096,
            gpu_memory_utilization=0.95,
            enforce_eager=True,
            trust_remote_code=True,
            # quantization="bitsandbytes",
            # dtype="half",
            # load_format="bitsandbytes",
            max_num_batched_tokens=4096,
            max_num_seqs=batch_size  # cap on concurrently scheduled sequences
        )
        self.sampling_params = SamplingParams(
            n=n,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
            stop=stop,
            logprobs=1  # request per-token logprobs so samples can be ranked by likelihood
        )
        self.batch_size = batch_size
        self.tokenizer = self.llm.get_tokenizer()
        print(f"Initialization complete. Batch size: {batch_size}")
    def parse_response(self, responses):
        all_outputs = []
        for response in responses:
            to_return = []
            for output in response.outputs:
                text = output.text.strip()
                try:
                    # Total log-probability of the sampled sequence, used for ranking.
                    logprob = sum(
                        logprob_obj.logprob
                        for item in output.logprobs
                        for logprob_obj in item.values()
                    )
                except Exception:
                    logprob = 0  # Fallback if logprobs aren't available
                to_return.append((text, logprob))
            # Order the n samples from most to least likely and keep only the text.
            texts = [r[0] for r in sorted(to_return, key=lambda tup: tup[1], reverse=True)]
            all_outputs.append(texts)
        return all_outputs
    def prepare_prompt(self, claim: str, model_name: str) -> str:
        base_prompt = (
            "Please write a fact-checking article passage to support, refute, "
            "indicate not enough evidence, or present conflicting evidence "
            f"regarding the claim.\nClaim: {claim}"
        )
        if "OLMo" in model_name:
            # OLMo models are prompted directly, without a chat template.
            return base_prompt
        else:
            # Otherwise wrap the prompt in the model's chat template; the appended
            # assistant header assumes a Llama-3-style template.
            messages = [{"role": "user", "content": base_prompt}]
            return (
                self.tokenizer.apply_chat_template(messages, tokenize=False)
                + "<|start_header_id|>assistant<|end_header_id|>\n\nPassage: "
            )
    def process_batch(self, batch: List[Dict[str, Any]], model_name: str) -> tuple[List[Dict[str, Any]], float]:
        start_time = time.time()
        prompts = [self.prepare_prompt(example["claim"], model_name) for example in batch]
        try:
            results = self.llm.generate(prompts, sampling_params=self.sampling_params)
            outputs = self.parse_response(results)
            # Attach the ranked hypothetical fact-checking passages to each example.
            for example, output in zip(batch, outputs):
                example['hypo_fc_docs'] = output
            batch_time = time.time() - start_time
            return batch, batch_time
        except Exception as e:
            print(f"Error processing batch: {str(e)}")
            return batch, time.time() - start_time
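# Standalone usage sketch (illustrative only; the example claim and single-item
# batch are assumptions, not part of the original pipeline):
#
#     generator = VLLMGenerator("meta-llama/Llama-3.1-8B-Instruct", batch_size=8)
#     batch, elapsed = generator.process_batch(
#         [{"claim": "The Eiffel Tower was moved to Berlin in 2020."}],
#         "meta-llama/Llama-3.1-8B-Instruct",
#     )
#     ranked_passages = batch[0]["hypo_fc_docs"]  # best-first list of strings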
# def format_time(seconds: float) -> str:
#     return str(timedelta(seconds=int(seconds)))
# def estimate_completion_time(start_time: float, processed_examples: int, total_examples: int) -> str:
#     elapsed_time = time.time() - start_time
#     examples_per_second = processed_examples / elapsed_time
#     remaining_examples = total_examples - processed_examples
#     estimated_remaining_seconds = remaining_examples / examples_per_second
#     completion_time = datetime.now() + timedelta(seconds=int(estimated_remaining_seconds))
#     return completion_time.strftime("%Y-%m-%d %H:%M:%S")
def main(args):
    total_start_time = time.time()
    print(f"Script started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    print("Loading data...")
    with open(args.target_data, 'r', encoding='utf-8') as json_file:
        examples = json.load(json_file)
    print(f"Loaded {len(examples)} examples")

    # Initialize generator
    print("Initializing generator...")
    generator = VLLMGenerator(
        model_name=args.model,
        batch_size=32
    )

    # Process data in batches
    processed_data = []
    # batch_times = []
    batches = [examples[i:i + generator.batch_size] for i in range(0, len(examples), generator.batch_size)]
    print(f"\nProcessing {len(batches)} batches...")
    with tqdm(total=len(examples), desc="Processing examples") as pbar:
        for batch_idx, batch in enumerate(batches, 1):
            processed_batch, batch_time = generator.process_batch(batch, args.model)
            processed_data.extend(processed_batch)
            # batch_times.append(batch_time)
            # Update progress and timing information
            # examples_processed = len(processed_data)
            # avg_batch_time = sum(batch_times) / len(batch_times)
            # estimated_completion = estimate_completion_time(total_start_time, examples_processed, len(examples))
            # pbar.set_postfix({
            #     'Batch': f"{batch_idx}/{len(batches)}",
            #     'Avg Batch Time': f"{avg_batch_time:.2f}s",
            #     'ETA': estimated_completion
            # })
            pbar.update(len(batch))

    # Calculate and display timing statistics
    # total_time = time.time() - total_start_time
    # avg_batch_time = sum(batch_times) / len(batch_times)
    # avg_example_time = total_time / len(examples)
    # print("\nTiming Statistics:")
    # print(f"Total Runtime: {format_time(total_time)}")
    # print(f"Average Batch Time: {avg_batch_time:.2f} seconds")
    # print(f"Average Time per Example: {avg_example_time:.2f} seconds")
    # print(f"Throughput: {len(examples)/total_time:.2f} examples/second")

    # Save results
    # print("\nSaving results...")
    with open(args.json_output, "w", encoding="utf-8") as output_json:
        json.dump(processed_data, output_json, ensure_ascii=False, indent=4)
    # print(f"Script completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    # print(f"Total runtime: {format_time(total_time)}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--target_data', default='data_store/averitec/dev.json')
    parser.add_argument('-o', '--json_output', default='data_store/hyde_fc.json')
    parser.add_argument('-m', '--model', default="meta-llama/Llama-3.1-8B-Instruct")
    args = parser.parse_args()
    main(args)
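# Example invocation (the script filename is illustrative; the paths and model
# shown are the argparse defaults above):
#
#     python hyde_fc_generation.py \
#         -i data_store/averitec/dev.json \
#         -o data_store/hyde_fc.json \
#         -m meta-llama/Llama-3.1-8B-Instruct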