# Suppress unavoidable warnings from threadpoolctl
# (https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md)
# and from the transformers package.
import warnings
warnings.filterwarnings("ignore")

import re
import logging
import traceback

import numpy as np
import openai
from tqdm import tqdm
from sentence_transformers import util
from openai.error import InvalidRequestError

from .utils import break_down2scenes
from .prompt import build_fact_prompt
from .openai_api import openai_api_response

logger = logging.getLogger(__name__)


class OpenAIEmbedding:
    """Thin wrapper around the OpenAI embeddings endpoint (legacy openai<1.0 SDK)."""

    def __init__(self, api_key, model="text-embedding-3-large"):
        self.api_key = api_key
        self.model = model
        openai.api_key = api_key

    def encode(self, texts, **kwargs):
        # Accept a single string as well as a list of strings.
        if isinstance(texts, str):
            texts = [texts]
        try:
            response = openai.Embedding.create(
                model=self.model,
                input=texts,
            )
            # Extract one embedding vector per input text.
            embeddings = [item["embedding"] for item in response["data"]]
            return np.array(embeddings)
        except Exception as e:
            logger.error(f"Embedding API failed: {str(e)}")
            return None
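
# Illustrative usage of OpenAIEmbedding (a sketch; the key below is a placeholder,
# and this module assumes the legacy openai<1.0 SDK throughout):
#
#   embedder = OpenAIEmbedding(api_key="sk-...")
#   vecs = embedder.encode(["Scene one.", "Scene two."])  # np.ndarray, one row per text
#   sim = util.cos_sim(vecs[0], vecs[1])                  # 1x1 cosine-similarity tensor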


class NarrativeFactScore():
    def __init__(self, model="gptscore", split_type="fast", checkpoint=None, api_key=None, model_id="gpt-4"):
        # `model` selects the scoring backend; `model_id` is the OpenAI model it calls.
        self.sent_model = OpenAIEmbedding(api_key=api_key)
        self.split_type = split_type
        self.checkpoint = checkpoint
        self.api_key = api_key
        self.model_id = model_id
        openai.api_key = api_key
        if model == "gptscore":
            self.metric = GPTScore(model=self.model_id, api_key=self.api_key)
            self.metric_function = self.metric.gpt_score
        else:
            raise ValueError("NarrativeFactScore currently only supports GPTScore")
    def get_surrounding_sentences(self, sentence_array, ii):
        # Return sentence ii together with its immediate neighbours.
        if ii > 0 and ii < len(sentence_array) - 1:
            sents = " ".join(np.array(sentence_array)[ii - 1 : ii + 2])
        elif ii == 0:
            sents = " ".join(np.array(sentence_array)[:2])
        elif ii == len(sentence_array) - 1:
            sents = " ".join(np.array(sentence_array)[ii - 1 :])
        return sents
    def group_into_sections(self, sentence_array, num_sent):
        # Join consecutive groups of `num_sent` sentences into sections.
        sectioned_sents = []
        for ii in range(0, len(sentence_array), num_sent):
            sectioned_sents.append(" ".join(sentence_array[ii : ii + num_sent]))
        return sectioned_sents
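
    # Illustrative example for the helper above (hypothetical input):
    #   group_into_sections(["a.", "b.", "c.", "d.", "e."], num_sent=2)
    #   -> ["a. b.", "c. d.", "e."]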
    def split_sent(self, text):
        # Break a summary into smaller claim-level chunks.
        text_list = []
        if self.split_type == "fast":
            # Naive split on periods.
            for t in text.split('.'):
                if len(t) == 0:
                    continue
                text_list.append(t)
            return text_list
        elif self.split_type == "fast_comma":
            # Split on periods and commas.
            for t in re.split(r'[.,]', text):
                if len(t) == 0:
                    continue
                text_list.append(t)
            return text_list
        elif self.split_type == "gpt":
            # Ask the LLM to decompose the text into atomic facts, one per line.
            prompt = build_fact_prompt(
                prompt_template='./templates/atomic_fact.txt',
                input_text_list=[text],
            )
            response = openai_api_response(prompt, model=self.model_id, api_key=self.api_key)
            for res in response.split('\n'):
                if res.strip():  # skip blank lines, mirroring the other branches
                    text_list.append(res.strip())
            return text_list
        else:
            return None
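
    # Illustrative example (hypothetical input):
    #   split_sent("He left. She stayed.") with split_type="fast"
    #   -> ["He left", " She stayed"]  (leading whitespace is preserved)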
    def score_src_hyp_long(self, srcs, hyps, kgs):
        # srcs: list of source documents (scripts).
        # hyps: list of generated summaries, aligned with srcs.
        # kgs:  knowledge-graph triples extracted from the sources.
        all_scores = []
        all_scores_per_sent = []
        all_relevant_scenes = []
        all_summary_chunks = []
        all_feedback_list = []
        total_score = 0
        for global_idx, (src, hyp) in enumerate(zip(tqdm(srcs), hyps)):
            src_sents = break_down2scenes(src)
            # Embed source scenes and KG triples with the OpenAI API.
            sentence_embeddings_src = self.sent_model.encode(src_sents)
            sentence_embeddings_kg = self.sent_model.encode(kgs)
            doc_scores = []
            relevant_scenes = []
            feedbacks = []
            hyp_array = self.split_sent(hyp)
            for idx, hyp_sentence in enumerate(hyp_array):
                # Embed the hypothesis chunk and rank scenes/triples by cosine similarity.
                sentence_embeddings_hyp = self.sent_model.encode(hyp_sentence)
                scores = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_src)[0]
                scores_kg = util.cos_sim(sentence_embeddings_hyp, sentence_embeddings_kg)[0]
                sorted_idxs = np.argsort(-1 * scores)  # descending order
                sorted_idxs_kg = np.argsort(-1 * scores_kg)  # descending order
                # Keep the single most similar KG triple (the slice makes top-k easy later).
                triple = ''
                for sorted_idx, ii in enumerate(sorted_idxs_kg[0:1]):
                    if sorted_idx == 0:
                        triple += f'{kgs[ii]}'
                    else:
                        triple += f', {kgs[ii]}'
                # Keep the single most similar source scene.
                similar_src_sentences = []
                for ii in sorted_idxs[0:1]:
                    similar_src_sentences.append(src_sents[ii])
                scores, feedback_list = self.metric_function(
                    similar_src_sentences,
                    [hyp_sentence] * len(similar_src_sentences),
                    triple,
                )
                # Take the best-supported scene for this chunk.
                score = np.max(scores)
                max_scene_idx = np.argmax(scores)
                max_scene = similar_src_sentences[max_scene_idx]
                feedback = feedback_list[max_scene_idx]
                doc_scores.append(int(score))
                relevant_scenes.append(max_scene)
                feedbacks.append(feedback)
            doc_score = np.mean(doc_scores)
            all_scores_per_sent.append(doc_scores)
            all_scores.append(doc_score)
            all_relevant_scenes.append(relevant_scenes)
            all_summary_chunks.append(hyp_array)
            all_feedback_list.append(feedbacks)
            total_score += doc_score
            # Periodic progress report.
            if global_idx % 100 == 99:
                print(f"Mean score over first {global_idx + 1} documents: {total_score / (global_idx + 1):.4f}")
        return all_scores, all_scores_per_sent, all_relevant_scenes, all_summary_chunks, all_feedback_list


class GPTScore():
    def __init__(self, model="gpt-4o", api_key=None, prompt='./templates/fact_score_kg.txt'):
        self.max_length = 1024
        self.model = model
        self.api_key = api_key
        self.prompt = prompt
        openai.api_key = api_key

    def gpt_inference(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=prompt_messages,
                temperature=0,
                api_key=self.api_key,
            )
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Return the string "1" (not the int 1) so the `'1' in score`
            # membership test in gpt_score does not raise a TypeError.
            response = "1"
        return response
    def gpt_score(self, srcs, tgts, kgs, batch_size=4):
        score_list = []
        feedback_list = []
        for i in range(len(srcs)):
            src = srcs[i]
            tgt = tgts[i]
            prompt = build_fact_prompt(
                prompt_template=self.prompt,
                input_text_list=[src, kgs, tgt],
            )
            try:
                score = self.gpt_inference(prompt)
                if '1' in score:
                    # Consistent with the source: score 1, no feedback.
                    score_list.append(float(1))
                    feedback_list.append('')
                else:
                    # Inconsistent: score 0, keep the model's explanation as feedback.
                    score_list.append(float(0))
                    feedback_list.append(score)
            except RuntimeError:
                traceback.print_exc()
                print(f"source: {srcs}")
                print(f"target: {tgts}")
                exit(1)
        return score_list, feedback_list
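

# Minimal end-to-end usage sketch (hypothetical; this module uses relative imports,
# so run it from outside the package). The package name, script/summary strings,
# KG triples, and API key below are all placeholders:
#
#   from narrative_factscore import NarrativeFactScore  # assumed package name
#
#   scorer = NarrativeFactScore(model="gptscore", split_type="fast",
#                               api_key="sk-...", model_id="gpt-4")
#   srcs = ["Scene 1: Alice meets Bob. Scene 2: They argue."]
#   hyps = ["Alice meets Bob. They become friends."]
#   kgs  = ["(Alice, meets, Bob)", "(Alice, argues_with, Bob)"]
#   scores, per_sent, scenes, chunks, feedback = scorer.score_src_hyp_long(srcs, hyps, kgs)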