# NOTE(review): the lines "Spaces:" / "Sleeping" / "Sleeping" below were
# HuggingFace Spaces page scrape residue, not Python — commented out so the
# module parses.
# Spaces:
# Sleeping
# Sleeping
import argparse
import contextlib
import gc
import json
import os
import re
import time
from datetime import datetime, timedelta

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer
def encode_text(model, tokenizer, texts, batch_size=8, max_length=512):
    """Encode texts into mean-pooled sentence embeddings.

    Args:
        model: Hugging Face ``AutoModel`` whose first output element holds
            the per-token embeddings.
        tokenizer: the tokenizer matching ``model``.
        texts: list of strings to embed.
        batch_size: number of texts per forward pass.
        max_length: truncation length in tokens.

    Returns:
        np.ndarray of shape (len(texts), hidden_dim), one row per input
        text; an empty (0, 0) array when ``texts`` is empty.
    """
    # np.vstack rejects an empty list, so short-circuit on no input.
    if len(texts) == 0:
        return np.empty((0, 0), dtype=np.float32)
    use_cuda = torch.cuda.is_available()
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # Tokenize and move the batch to the model's device.
        encoded_input = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        ).to(model.device)
        # torch.cuda.amp.autocast is deprecated; use torch.autocast and skip
        # mixed precision entirely on CPU-only machines.
        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if use_cuda
            else contextlib.nullcontext()
        )
        with torch.no_grad(), autocast_ctx:
            model_output = model(**encoded_input)
        # Mean pooling over valid (non-padding) tokens only.
        attention_mask = encoded_input['attention_mask']
        token_embeddings = model_output[0]  # first element contains token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # Clamp the denominator so an all-padding row cannot divide by zero.
        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        all_embeddings.append(embeddings.cpu().numpy())
        # Periodically release memory between batches; empty_cache only
        # makes sense (and is only safe) when CUDA is present.
        if i % (batch_size * 4) == 0:
            if use_cuda:
                torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)
def compute_similarity(emb1, emb2):
    """Pairwise cosine similarity between two sets of row vectors.

    Args:
        emb1: array of shape (m, d).
        emb2: array of shape (n, d).

    Returns:
        np.ndarray of shape (m, n) where entry (i, j) is the cosine
        similarity between ``emb1[i]`` and ``emb2[j]``.
    """
    row_norms = np.linalg.norm(emb1, axis=1)[:, None]
    col_norms = np.linalg.norm(emb2, axis=1)[None, :]
    return (emb1 @ emb2.T) / (row_norms * col_norms)
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a query with its task instruction in the two-line prompt
    format expected by instruction-tuned embedding models."""
    return "Instruct: {}\nQuery: {}".format(task_description, query)
def preprocess_sentences(sentence1, sentence2):
    """Return the TF-IDF cosine similarity between two sentences.

    A fresh vectorizer is fitted on just the two inputs, so the score
    reflects lexical overlap between them only.
    """
    tfidf_matrix = TfidfVectorizer().fit_transform([sentence1, sentence2]).toarray()
    pairwise = cosine_similarity(tfidf_matrix)
    return pairwise[0][1]
def remove_trailing_special_chars(text):
    """Strip any trailing run of non-alphanumeric characters
    (underscores included) from *text*."""
    trailing_junk = re.compile(r'[\W_]+$')
    return trailing_junk.sub('', text)
def remove_special_chars_except_spaces(text):
    """Drop every character that is neither a word character nor
    whitespace (i.e. remove punctuation, keep letters/digits/_/spaces)."""
    cleaned = re.sub(r'[^\w\s]+', '', text)
    return cleaned
def select_top_k(claim, results, top_k):
    """Pick the first ``top_k`` ranked sentences that are not
    near-duplicates of the claim.

    Each candidate is normalized (punctuation stripped, lowercased) and
    rejected when it:
      * repeats an already-seen normalized sentence;
      * has TF-IDF cosine similarity with the claim above 0.97;
      * contains the claim with a length ratio above 0.92 (near-verbatim).

    Args:
        claim: the claim text being fact-checked.
        results: ranked list of dicts with 'sentence' and 'url' keys.
        top_k: maximum number of sentences to return.

    Returns:
        List of at most ``top_k`` {'sentence', 'url'} dicts in rank order.

    Fixes vs. the original: the ``continue``-based flow used to re-scan
    each rejected sentence on an extra loop iteration, accepted sentences
    were never added to the dedup set (so exact duplicates could be
    returned), and an empty normalized sentence could crash the TF-IDF
    vectorizer or the length-ratio division.
    """
    seen = set()
    selected = []
    norm_claim = remove_special_chars_except_spaces(claim).lower()
    for result in results:
        if len(selected) >= top_k:
            break
        norm_sentence = remove_special_chars_except_spaces(result['sentence']).lower()
        if norm_sentence in seen:
            continue
        seen.add(norm_sentence)
        # Guard: an empty normalized sentence would make TfidfVectorizer
        # raise (empty vocabulary) and the ratio below divide by zero.
        if not norm_sentence:
            continue
        if preprocess_sentences(norm_claim, norm_sentence) > 0.97:
            continue
        if norm_claim in norm_sentence and len(norm_claim) / len(norm_sentence) > 0.92:
            continue
        selected.append({
            'sentence': result['sentence'],
            'url': result['url']
        })
    return selected
# def format_time(seconds):
#     """Format time duration nicely."""
#     return str(timedelta(seconds=round(seconds)))
def compute_embeddings_batched(model, texts, batch_size=8):
    """Compute sentence embeddings in small batches to bound memory.

    Args:
        model: an object exposing ``encode(batch, batch_size=...,
            show_progress_bar=...)`` returning an array of embeddings
            (e.g. a SentenceTransformer).
        texts: list of strings to embed.
        batch_size: number of texts per ``encode`` call.

    Returns:
        np.ndarray stacking one embedding row per input text; an empty
        (0, 0) array when ``texts`` is empty.
    """
    # np.vstack rejects an empty list, so short-circuit on no input.
    if len(texts) == 0:
        return np.empty((0, 0), dtype=np.float32)
    use_cuda = torch.cuda.is_available()
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # torch.cuda.amp.autocast is deprecated; use torch.autocast and skip
        # mixed precision entirely on CPU-only machines.
        autocast_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if use_cuda
            else contextlib.nullcontext()
        )
        with autocast_ctx:
            emb = model.encode(batch, batch_size=len(batch), show_progress_bar=False)
        all_embeddings.append(emb)
        # Periodically release memory between batches; empty_cache only
        # applies when CUDA is present.
        if i % (batch_size * 4) == 0:
            if use_cuda:
                torch.cuda.empty_cache()
            gc.collect()
    return np.vstack(all_embeddings)
def main(args):
    """Rerank pre-retrieved evidence sentences for each claim.

    For every claim in ``args.target_data`` (JSONL), builds an averaged
    query embedding from the claim plus its hypothetical fact-check
    documents, scores the claim's ``top_5000`` candidate sentences by
    cosine similarity, and writes the filtered top-k sentences (with
    URLs) to ``args.json_output`` as JSONL.
    """
    device = "cuda" if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    # Load the embedding model and its tokenizer.
    model = AutoModel.from_pretrained(
        "Salesforce/SFR-Embedding-2_R",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/SFR-Embedding-2_R")
    # Load target examples: one JSON object per line; report (but skip)
    # malformed lines instead of silently swallowing every exception.
    target_examples = []
    with open(args.target_data, "r", encoding="utf-8") as json_file:
        for i, line in enumerate(json_file):
            try:
                target_examples.append(json.loads(line))
            except json.JSONDecodeError:
                print(f"CURRENT LINE broken {i}")
    if args.end == -1:
        args.end = len(target_examples)
    # A range gives O(1) membership tests, unlike the original list scan.
    files_to_process = range(args.start, args.end)
    total = len(files_to_process)
    task = 'Given a web search query, retrieve relevant passages that answer the query'
    with open(args.json_output, "w", encoding="utf-8") as output_json:
        done = 0
        for idx, example in enumerate(target_examples):
            if idx not in files_to_process:
                continue
            print(f"Processing claim {example['claim_id']}... Progress: {done + 1} / {total}")
            claim = example['claim']
            # Query set: the claim itself plus every non-empty
            # hypothetical fact-check document, instruction-wrapped.
            query = [get_detailed_instruct(task, claim)] + [
                get_detailed_instruct(task, le)
                for le in example['hypo_fc_docs']
                if len(le.strip()) > 0
            ]
            sentences = [sent['sentence'] for sent in example['top_5000']][:args.retrieved_top_k]
            try:
                # Average the query embeddings into one HyDE-style vector.
                query_embeddings = encode_text(model, tokenizer, query, batch_size=4)
                hyde_vector = np.mean(query_embeddings, axis=0).reshape((1, -1))
                sentence_embeddings = encode_text(
                    model,
                    tokenizer,
                    sentences,
                    batch_size=args.batch_size
                )
                # Score candidates in chunks to bound peak memory.
                chunk_size = 1000
                all_scores = []
                for i in range(0, len(sentence_embeddings), chunk_size):
                    chunk = sentence_embeddings[i:i + chunk_size]
                    all_scores.extend(compute_similarity(hyde_vector, chunk)[0])
                scores = np.array(all_scores)
                # Rank all candidates by descending similarity.
                top_k_idx = np.argsort(scores)[::-1]
                results = [example['top_5000'][i] for i in top_k_idx]
                top_k_sentences_urls = select_top_k(claim, results, args.top_k)
                json_data = {
                    "claim_id": example['claim_id'],
                    "claim": claim,
                    f"top_{args.top_k}": top_k_sentences_urls
                }
                output_json.write(json.dumps(json_data, ensure_ascii=False) + "\n")
                output_json.flush()
            except RuntimeError as e:
                # Typically CUDA OOM: skip this claim, keep the run alive.
                print(f"Error processing claim {example['claim_id']}: {e}")
                continue
            done += 1
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run the pipeline.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--target_data", default="data_store/dev_retrieval_top_k.json")
    arg_parser.add_argument("--retrieved_top_k", type=int, default=5000)
    arg_parser.add_argument("--top_k", type=int, default=10)
    arg_parser.add_argument("-o", "--json_output", type=str, default="data_store/dev_reranking_top_k.json")
    arg_parser.add_argument("--batch_size", type=int, default=32)
    arg_parser.add_argument("-s", "--start", type=int, default=0)
    arg_parser.add_argument("-e", "--end", type=int, default=-1)
    main(arg_parser.parse_args())