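"""Generate questions that link claims to their top-k retrieved evidence.

For each claim, every retrieved evidence sentence is paired with BM25-selected
few-shot examples drawn from a reference corpus, and a vLLM-served model
generates a question that the sentence could answer. Results are written as
JSON Lines to --output_questions.
"""
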
import os
import argparse
import time
import json
from datetime import datetime, timedelta

import nltk
import numpy as np
import torch
from rank_bm25 import BM25Okapi
from vllm import LLM, SamplingParams


def download_nltk_data(package_name, download_dir='nltk_data'):
    """Download an NLTK tokenizer package into download_dir unless it is already present."""
    # Ensure the download directory exists
    os.makedirs(download_dir, exist_ok=True)
    # Make the local directory visible to NLTK's resource lookup
    nltk.data.path.append(download_dir)
    try:
        # Check whether the resource is already available
        nltk.data.find(f'tokenizers/{package_name}')
        print(f"Package '{package_name}' is already downloaded")
    except LookupError:
        # The resource isn't found, so download it
        print(f"Downloading {package_name}...")
        nltk.download(package_name, download_dir=download_dir)
        print(f"Successfully downloaded {package_name}")


# def format_time(seconds):
#     """Format time duration nicely."""
#     return str(timedelta(seconds=round(seconds)))


def claim2prompts(example):
    """Yield (lookup string, formatted prompt) pairs for each QA pair of a reference example."""
    claim = example["claim"]
    claim_str = "Example [NUMBER]:||Claim: " + claim + "||Evidence: "
    for question in example["questions"]:
        q_text = question["question"].strip()
        if len(q_text) == 0:
            continue
        if not q_text.endswith("?"):
            q_text += "?"
        answer_strings = []
        for a in question["answers"]:
            if a["answer_type"] in ["Extractive", "Abstractive"]:
                answer_strings.append(a["answer"])
            if a["answer_type"] == "Boolean":
                answer_strings.append(a["answer"] + ", because " + a["boolean_explanation"].lower().strip())
        for a_text in answer_strings:
            if a_text[-1] not in [".", "!", ":", "?"]:
                a_text += "."
            prompt_lookup_str = a_text
            this_q_claim_str = claim_str + a_text.strip() + "||Question: " + q_text
            yield (prompt_lookup_str, this_q_claim_str.replace("\n", " ").replace("||", "\n")[:1500])
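
# For reference, each prompt yielded by claim2prompts renders (after the
# "||" -> newline replacement) as:
#   Example [NUMBER]:
#   Claim: <claim>
#   Evidence: <answer text>
#   Question: <question text>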


def main(args):
    # script_start = time.time()
    # start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # print(f"Script started at: {start_time}")
    # print(f"Loading model: {args.model}")
    download_nltk_data('punkt')
    download_nltk_data('punkt_tab')
    # Load and prepare the reference corpus
    # corpus_start = time.time()
    with open(args.reference_corpus, "r", encoding="utf-8") as json_file:
        train_examples = json.load(json_file)
    prompt_corpus, tokenized_corpus = [], []
    for example in train_examples:
        for lookup_str, prompt in claim2prompts(example):
            entry = nltk.word_tokenize(lookup_str)
            tokenized_corpus.append(entry)
            prompt_corpus.append(prompt)
    prompt_bm25 = BM25Okapi(tokenized_corpus)
    # print(f"Reference corpus processed in: {format_time(time.time() - corpus_start)}")
    # Initialize vLLM with optimized settings
    gpu_count = torch.cuda.device_count()
    print(f"Using {gpu_count} GPU{'s' if gpu_count > 1 else ''}")
    # model_start = time.time()
    llm = LLM(
        model=args.model,
        tensor_parallel_size=gpu_count,
        max_model_len=4096,
        gpu_memory_utilization=0.95,
        enforce_eager=True,
        trust_remote_code=True,
        # dtype="half",
    )
    # Llama-3 tokenizers ship without a pad token, so set one explicitly
    llm.get_tokenizer().pad_token = "<|end_of_text|>"
    # print(f"Model loaded in: {format_time(time.time() - model_start)}")
    # top_k=1 makes decoding effectively greedy despite the temperature/top_p settings
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        top_k=1,
        skip_special_tokens=False,
        max_tokens=512,
        # Stop strings cover the end-of-turn markers of several model families
        stop=['<|end_of_text|>', '</s>', '<|im_end|>', '[INST]', '[/INST]',
              '<|eot_id|>', '<|end|>', '<|endoftext|>'],
    )
    # processing_start = time.time()
    # Load the target data (one JSON object per line)
    target_examples = []
    with open(args.top_k_target_knowledge, "r", encoding="utf-8") as json_file:
        for line in json_file:
            target_examples.append(json.loads(line))
    if args.end == -1:
        args.end = len(target_examples)
    print(f"Processing {args.end} examples")
    # Process the target examples in batches
    with torch.no_grad():
        with open(args.output_questions, "w", encoding="utf-8") as output_file:
            for idx in range(0, args.end, args.batch_size):
                batch_end = min(idx + args.batch_size, args.end)
                current_batch = target_examples[idx:batch_end]
                print(f"\nProcessing batch {idx}-{batch_end}...")
                for example in current_batch:
                    # batch_start = time.time()
                    claim = example["claim"]
                    claim_id = example["claim_id"]
                    top_k_sentences_urls = example[f"top_{args.top_k}"]
                    batch_prompts = []
                    batch_metadata = []
                    # Prepare one prompt per retrieved evidence sentence
                    for sentences_urls in top_k_sentences_urls:
                        prompt_lookup_str = sentences_urls["sentence"]
                        url = sentences_urls["url"]
                        # Retrieve the most lexically similar reference QA pairs
                        # to use as few-shot examples
                        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
                        prompt_n = 10
                        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
                        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
                        temp_prompt = "\n\n".join(prompt_docs)
                        # Number the few-shot examples sequentially
                        for k in range(1, temp_prompt.count("[NUMBER]") + 1):
                            temp_prompt = temp_prompt.replace("[NUMBER]", f"{k}", 1)
                        claim_prompt = "Your task is to generate a question based on the given claim and evidence. The question should clarify the relationship between the evidence and the claim.\n\n"
                        evidence = prompt_lookup_str.replace("\n", " ")
                        full_prompt = (
                            claim_prompt
                            + temp_prompt
                            + "\n\nNow, generate a question that links the following claim and evidence:"
                            + f"\n\nClaim: {claim}"
                            + f"\nEvidence: {evidence}"
                        )
                        if "OLMo" in args.model:
                            # OLMo models are prompted directly, without a chat template
                            inputs = full_prompt
                        else:
                            # Wrap the prompt in the model's chat template and seed the
                            # assistant turn (Llama-3 header format) so generation
                            # starts right after "Question: "
                            messages = [{"role": "user", "content": full_prompt}]
                            inputs = llm.get_tokenizer().apply_chat_template(messages, tokenize=False)
                            inputs += "<|start_header_id|>assistant<|end_header_id|>\n\nQuestion: "
                        batch_prompts.append(inputs)
                        batch_metadata.append((url, prompt_lookup_str))
                    # Generate questions for all prompts of this example at once
                    outputs = llm.generate(batch_prompts, sampling_params)
                    # Keep the generated text up to the first question mark
                    evidence = []
                    for output, (url, sent) in zip(outputs, batch_metadata):
                        question = output.outputs[0].text.strip().split("?")[0].replace("\n", " ") + "?"
                        evidence.append({
                            "question": question,
                            "answer": sent,
                            "url": url,
                        })
                    # Write one JSON line per claim
                    json_data = {
                        "claim_id": claim_id,
                        "claim": claim,
                        "evidence": evidence,
                    }
                    output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                    output_file.flush()
                    # batch_time = time.time() - batch_start
                    # print(f"Processed example {claim_id}. Time elapsed: {batch_time:.2f}s")

    # Calculate and display timing information
    # total_time = time.time() - script_start
    # processing_time = time.time() - processing_start
    # end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # print("\nTiming Summary:")
    # print(f"Start time: {start_time}")
    # print(f"End time: {end_time}")
    # print(f"Total runtime: {format_time(total_time)}")
    # print(f"Setup time: {format_time(processing_start - script_start)}")
    # print(f"Processing time: {format_time(processing_time)}")
    # print(f"Results written to: {args.output_questions}")
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description="Use a prompt to generate questions that could be answered by top-k retrieved evidence. Output generated questions.") | |
parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") | |
parser.add_argument("--reference_corpus", default="baseline/train.json") | |
parser.add_argument( | |
"-i", | |
"--top_k_target_knowledge", | |
default="data_store/dev_reranking_top_k.json", | |
help="Directory where the sentences for the scraped data is saved.", | |
) | |
parser.add_argument( | |
"-o", | |
"--output_questions", | |
default="data_store/dev_top_k_qa.json", | |
help="Directory where the sentences for the scraped data is saved.", | |
) | |
parser.add_argument( | |
"--top_k", | |
default=10, | |
type=int | |
) | |
parser.add_argument( | |
"--batch_size", | |
type=int, | |
default=4, | |
help="Number of examples to process in each batch" | |
) | |
parser.add_argument( | |
"-e", | |
"--end", | |
type=int, | |
default=-1 | |
) | |
args = parser.parse_args() | |
main(args) |
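
# Example invocation (the script filename here is illustrative):
#   python question_generation.py \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --reference_corpus baseline/train.json \
#       -i data_store/dev_reranking_top_k.json \
#       -o data_store/dev_top_k_qa.json \
#       --top_k 10 --batch_size 4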