# Pledge_Tracker/system/baseline/reranking_optimized.py
# (Hugging Face Hub page residue removed: uploaded by yulongchen, commit fcd14e1 "Add system")
import os
import torch
import gc
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import argparse
import time
from datetime import datetime, timedelta
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
def encode_text(model, tokenizer, texts, batch_size=8, max_length=512):
    """Encode a list of texts into one (len(texts), hidden_dim) numpy array.

    Tokenizes in mini-batches, runs the model under bfloat16 autocast,
    mean-pools the token embeddings using the attention mask, and
    periodically frees CUDA cache to bound peak memory.

    Args:
        model: HuggingFace AutoModel; must expose ``.device`` and return
            token embeddings as the first element of its output.
        tokenizer: matching AutoTokenizer (callable, returns tensors).
        texts: list of strings to embed.
        batch_size: number of texts per forward pass.
        max_length: truncation length in tokens.

    Returns:
        np.ndarray of shape (len(texts), hidden_dim); shape (0, 0) when
        ``texts`` is empty.
    """
    # Guard: np.vstack([]) raises ValueError on an empty input list.
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast and is
    # also valid on CPU (where the cuda variant was a silent no-op).
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # Tokenize the batch and move it to the model's device.
        encoded_input = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(model.device)
        with torch.no_grad():
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                model_output = model(**encoded_input)
                # Mean pooling over valid (non-padding) tokens only.
                attention_mask = encoded_input['attention_mask']
                token_embeddings = model_output[0]  # first element holds token embeddings
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # .float() before .numpy(): bfloat16 tensors cannot convert to numpy.
        all_embeddings.append(embeddings.float().cpu().numpy())
        # Release cached GPU blocks every 4 batches to keep peak memory down.
        if i % (batch_size * 4) == 0:
            torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)
def compute_similarity(emb1, emb2):
    """Pairwise cosine similarity between two sets of row vectors.

    Args:
        emb1: array of shape (n1, dim).
        emb2: array of shape (n2, dim).

    Returns:
        (n1, n2) matrix whose (i, j) entry is the cosine similarity
        between emb1[i] and emb2[j].
    """
    norms1 = np.linalg.norm(emb1, axis=1, keepdims=True)
    norms2 = np.linalg.norm(emb2, axis=1, keepdims=True)
    return (emb1 @ emb2.T) / (norms1 * norms2.T)
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query with its task instruction for instruction-tuned embedding models."""
    return 'Instruct: ' + task_description + '\nQuery: ' + query
def preprocess_sentences(sentence1, sentence2):
    """Return the TF-IDF cosine similarity between two sentences.

    NOTE: despite the name, this does no preprocessing; it fits a
    TfidfVectorizer on the two sentences and returns their similarity.

    Returns:
        float in [0, 1]; 0.0 when no vocabulary can be built (e.g. both
        sentences are empty or punctuation-only), where the vectorizer
        would otherwise raise ValueError("empty vocabulary").
    """
    try:
        tfidf = TfidfVectorizer().fit_transform([sentence1, sentence2])
    except ValueError:
        # Empty vocabulary: nothing to compare — treat as dissimilar
        # instead of crashing the caller (select_top_k).
        return 0.0
    vectors = tfidf.toarray()
    cosine_sim = cosine_similarity(vectors)
    return cosine_sim[0][1]
def remove_trailing_special_chars(text):
    """Strip any trailing run of non-word characters and underscores from *text*."""
    trailing_junk = r'[\W_]+$'
    return re.sub(trailing_junk, '', text)
def remove_special_chars_except_spaces(text):
    """Delete punctuation/special characters, keeping word characters and whitespace."""
    cleaned = re.sub(r'[^\w\s]+', '', text)
    return cleaned
def select_top_k(claim, results, top_k):
    """Select up to ``top_k`` ranked sentences, dropping duplicates and near-copies of the claim.

    A candidate is skipped when:
      * it was already seen (exact duplicate after cleaning + lowercasing),
      * its TF-IDF cosine similarity with the claim exceeds 0.97,
      * the claim is contained in it and covers > 92% of its length
        (i.e. the sentence is essentially the claim itself).

    Args:
        claim: the claim text.
        results: ranked list of dicts with 'sentence' and 'url' keys.
        top_k: maximum number of sentences to return.

    Returns:
        list of {'sentence', 'url'} dicts, at most ``top_k`` long, in rank order.
    """
    dup_check = set()
    top_k_sentences_urls = []
    i = 0
    claim = remove_special_chars_except_spaces(claim).lower()
    while len(top_k_sentences_urls) < top_k and i < len(results):
        sentence = remove_special_chars_except_spaces(results[i]['sentence']).lower()
        i += 1
        if sentence in dup_check:
            continue
        if preprocess_sentences(claim, sentence) > 0.97:
            dup_check.add(sentence)
            continue
        # len(sentence) guard avoids ZeroDivisionError when both the claim
        # and the cleaned sentence are empty.
        if claim in sentence and len(sentence) > 0 and len(claim) / len(sentence) > 0.92:
            dup_check.add(sentence)
            continue
        # BUGFIX: record accepted sentences too, so exact duplicates further
        # down the ranking are not emitted twice.
        dup_check.add(sentence)
        top_k_sentences_urls.append({
            'sentence': results[i - 1]['sentence'],
            'url': results[i - 1]['url']}
        )
    return top_k_sentences_urls
# def format_time(seconds):
# """Format time duration nicely."""
# return str(timedelta(seconds=round(seconds)))
def compute_embeddings_batched(model, texts, batch_size=8):
    """Compute embeddings with a SentenceTransformer-style model in small batches.

    NOTE(review): unused by main(), which calls encode_text instead; kept
    for the SentenceTransformer code path.

    Args:
        model: object exposing ``encode(batch, batch_size, show_progress_bar)``
            returning a numpy array of per-text embeddings.
        texts: list of strings to embed.
        batch_size: number of texts per encode call.

    Returns:
        np.ndarray of shape (len(texts), dim); shape (0, 0) when ``texts``
        is empty.
    """
    # Guard: np.vstack([]) raises ValueError on an empty list.
    if not texts:
        return np.zeros((0, 0), dtype=np.float32)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast and is
    # also valid on CPU.
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            emb = model.encode(batch, batch_size=len(batch), show_progress_bar=False)
        all_embeddings.append(emb)
        # Release cached GPU blocks every 4 batches to keep peak memory down.
        if i % (batch_size * 4) == 0:
            torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)
def main(args):
    """Rerank pre-retrieved sentences for each claim with SFR-Embedding-2_R.

    For every claim with index in [args.start, args.end): embeds the claim
    plus its hypothetical fact-check documents (HyDE), averages them into a
    single query vector, scores the pre-retrieved 'top_5000' sentences by
    cosine similarity, deduplicates with select_top_k, and appends the
    result as one JSON object per line to args.json_output.
    """
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model and tokenizer (bfloat16 + device_map="auto" to fit the
    # large embedding model across available devices).
    model = AutoModel.from_pretrained(
        "Salesforce/SFR-Embedding-2_R",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-2_R")

    # Load target examples (JSONL). Tolerate malformed lines but catch only
    # decode errors — a bare except here used to hide real bugs too.
    target_examples = []
    with open(args.target_data, "r", encoding="utf-8") as json_file:
        for i, line in enumerate(json_file):
            try:
                target_examples.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"CURRENT LINE broken {i}")

    if args.end == -1:
        args.end = len(target_examples)
    # A set gives O(1) membership per claim (was a list: O(n) per test).
    files_to_process = set(range(args.start, args.end))
    total = len(files_to_process)

    task = 'Given a web search query, retrieve relevant passages that answer the query'
    with open(args.json_output, "w", encoding="utf-8") as output_json:
        done = 0
        for idx, example in enumerate(target_examples):
            if idx not in files_to_process:
                continue
            print(f"Processing claim {example['claim_id']}... Progress: {done + 1} / {total}")
            claim = example['claim']
            # HyDE-style query set: the claim plus each non-empty
            # hypothetical fact-check document.
            query = [get_detailed_instruct(task, claim)] + [
                get_detailed_instruct(task, le)
                for le in example['hypo_fc_docs']
                if len(le.strip()) > 0
            ]
            sentences = [sent['sentence'] for sent in example['top_5000']][:args.retrieved_top_k]
            try:
                # Average all query-side embeddings into one vector.
                query_embeddings = encode_text(model, tokenizer, query, batch_size=4)
                hyde_vector = np.mean(query_embeddings, axis=0).reshape((1, -1))

                sentence_embeddings = encode_text(
                    model,
                    tokenizer,
                    sentences,
                    batch_size=args.batch_size
                )

                # Score in fixed-size chunks to cap peak memory.
                chunk_size = 1000
                all_scores = []
                for i in range(0, len(sentence_embeddings), chunk_size):
                    chunk = sentence_embeddings[i:i + chunk_size]
                    all_scores.extend(compute_similarity(hyde_vector, chunk)[0])
                scores = np.array(all_scores)

                # Rank descending by similarity, then dedupe and truncate.
                top_k_idx = np.argsort(scores)[::-1]
                results = [example['top_5000'][i] for i in top_k_idx]
                top_k_sentences_urls = select_top_k(claim, results, args.top_k)

                json_data = {
                    "claim_id": example['claim_id'],
                    "claim": claim,
                    f"top_{args.top_k}": top_k_sentences_urls
                }
                output_json.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                output_json.flush()  # keep partial output on crash/interrupt
            except RuntimeError as e:
                # Typically CUDA OOM on very long inputs; skip this claim
                # but keep processing the rest.
                print(f"Error processing claim {example['claim_id']}: {e}")
                continue
            done += 1
if __name__ == "__main__":
    # CLI entry point: configure the reranking run and hand off to main().
    parser = argparse.ArgumentParser()
    # JSONL of claims with pre-retrieved candidate sentences ('top_5000').
    parser.add_argument("--target_data", default="data_store/dev_retrieval_top_k.json")
    # How many retrieved sentences to rerank per claim.
    parser.add_argument("--retrieved_top_k", type=int, default=5000)
    # How many reranked sentences to keep per claim in the output.
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("-o", "--json_output", type=str, default="data_store/dev_reranking_top_k.json")
    # Batch size used when embedding candidate sentences.
    parser.add_argument("--batch_size", type=int, default=32)
    # Process example indices in [start, end); end == -1 means "to the end".
    parser.add_argument("-s", "--start", type=int, default=0)
    parser.add_argument("-e", "--end", type=int, default=-1)
    main(parser.parse_args())