import argparse import json import os import time import numpy as np import pandas as pd import nltk from rank_bm25 import BM25Okapi from multiprocessing import Pool, cpu_count, Manager, Lock from functools import partial import heapq from threading import Thread, Event import queue from datetime import datetime, timedelta def download_nltk_data(package_name, download_dir='nltk_data'): # Ensure the download directory exists os.makedirs(download_dir, exist_ok=True) # Set NLTK data path nltk.data.path.append(download_dir) try: # Try to find the resource nltk.data.find(f'tokenizers/{package_name}') print(f"Package '{package_name}' is already downloaded") except LookupError: # If resource isn't found, download it print(f"Downloading {package_name}...") nltk.download(package_name, download_dir=download_dir) print(f"Successfully downloaded {package_name}") def combine_all_sentences(knowledge_file): sentences, urls = [], [] with open(knowledge_file, "r", encoding="utf-8") as json_file: for i, line in enumerate(json_file): data = json.loads(line) sentences.extend(data["url2text"]) urls.extend([data["url"] for _ in range(len(data["url2text"]))]) return sentences, urls, i + 1 def remove_duplicates(sentences, urls): df = pd.DataFrame({"document_in_sentences":sentences, "sentence_urls":urls}) df['sentences'] = df['document_in_sentences'].str.strip().str.lower() df = df.drop_duplicates(subset="sentences").reset_index() return df['document_in_sentences'].tolist(), df['sentence_urls'].tolist() def retrieve_top_k_sentences(query, document, urls, top_k): tokenized_docs = [nltk.word_tokenize(doc) for doc in document[:top_k]] bm25 = BM25Okapi(tokenized_docs) scores = bm25.get_scores(nltk.word_tokenize(query)) top_k_idx = np.argsort(scores)[::-1][:top_k] return [document[i] for i in top_k_idx], [urls[i] for i in top_k_idx] def process_single_example(idx, example, args, result_queue, counter, lock): try: with lock: current_count = counter.value + 1 counter.value = current_count print(f"\nProcessing claim {idx}... Progress: {current_count} / {args.total_examples}") # start_time = time.time() document_in_sentences, sentence_urls, num_urls_this_claim = combine_all_sentences( os.path.join(args.knowledge_store_dir, f"{idx}.jsonl") ) print(f"Obtained {len(document_in_sentences)} sentences from {num_urls_this_claim} urls.") document_in_sentences, sentence_urls = remove_duplicates(document_in_sentences, sentence_urls) query = example["claim"] + " " + " ".join(example['hypo_fc_docs']) top_k_sentences, top_k_urls = retrieve_top_k_sentences( query, document_in_sentences, sentence_urls, args.top_k ) result = { "claim_id": idx, "claim": example["claim"], f"top_{args.top_k}": [ {"sentence": sent, "url": url} for sent, url in zip(top_k_sentences, top_k_urls) ], "hypo_fc_docs": example['hypo_fc_docs'] } result_queue.put((idx, result)) return True except Exception as e: print(f"Error processing example {idx}: {str(e)}") result_queue.put((idx, None)) return False def writer_thread(output_file, result_queue, total_examples, stop_event): next_index = 0 pending_results = [] with open(output_file, "w", encoding="utf-8") as f: while not (stop_event.is_set() and result_queue.empty()): try: idx, result = result_queue.get(timeout=1) if result is not None: heapq.heappush(pending_results, (idx, result)) while pending_results and pending_results[0][0] == next_index: _, result_to_write = heapq.heappop(pending_results) f.write(json.dumps(result_to_write, ensure_ascii=False) + "\n") f.flush() next_index += 1 except queue.Empty: continue # def format_time(seconds): # """Format time duration nicely.""" # return str(timedelta(seconds=round(seconds))) def main(args): download_nltk_data('punkt') download_nltk_data('punkt_tab') with open(args.target_data, "r", encoding="utf-8") as json_file: target_examples = json.load(json_file) if args.end == -1: args.end = len(target_examples) print(f"Total examples to process: {args.end - args.start}") files_to_process = list(range(args.start, args.end)) examples_to_process = [(idx, target_examples[idx]) for idx in files_to_process] num_workers = min(args.workers if args.workers > 0 else cpu_count(), len(files_to_process)) print(f"Using {num_workers} workers to process {len(files_to_process)} examples") with Manager() as manager: counter = manager.Value('i', 0) lock = manager.Lock() args.total_examples = len(files_to_process) result_queue = manager.Queue() stop_event = Event() writer = Thread( target=writer_thread, args=(args.json_output, result_queue, len(files_to_process), stop_event) ) writer.start() process_func = partial( process_single_example, args=args, result_queue=result_queue, counter=counter, lock=lock ) with Pool(num_workers) as pool: results = pool.starmap(process_func, examples_to_process) stop_event.set() writer.join() # successful = sum(1 for r in results if r) # print(f"\nSuccessfully processed {successful} out of {len(files_to_process)} examples") # print(f"Results written to {args.json_output}") # # Calculate and display timing information # total_time = time.time() - script_start # avg_time = total_time / len(files_to_process) # end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") # print("\nTiming Summary:") # print(f"Start time: {start_time}") # print(f"End time: {end_time}") # print(f"Total runtime: {format_time(total_time)} (HH:MM:SS)") # print(f"Average time per example: {avg_time:.2f} seconds") # if successful > 0: # print(f"Processing speed: {successful / total_time:.2f} examples per second") if __name__ == "__main__": parser = argparse.ArgumentParser( description="Get top 10000 sentences with BM25 in the knowledge store using parallel processing." ) parser.add_argument( "-k", "--knowledge_store_dir", type=str, default="data_store/knowledge_store", help="The path of the knowledge_store_dir containing json files with all the retrieved sentences.", ) parser.add_argument( "--target_data", type=str, default="data_store/hyde_fc.json", help="The path of the file that stores the claim.", ) parser.add_argument( "-o", "--json_output", type=str, default="data_store/dev_retrieval_top_k.json", help="The output dir for JSON files to save the top 100 sentences for each claim.", ) parser.add_argument( "--top_k", default=5000, type=int, help="How many documents should we pick out with BM25.", ) parser.add_argument( "-s", "--start", type=int, default=0, help="Starting index of the files to process.", ) parser.add_argument( "-e", "--end", type=int, default=-1, help="End index of the files to process.", ) parser.add_argument( "-w", "--workers", type=int, default=0, help="Number of worker processes (default: number of CPU cores)", ) args = parser.parse_args() main(args)