import datetime
import string
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
stop_words = stopwords.words('english')
import time
import arxiv
import colorlog
import torch

fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'
log_colors = {
    'DEBUG': 'white',
    'INFO': 'green',
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'purple'
}
colorlog.basicConfig(log_colors=log_colors, format=fmt_string, level=colorlog.INFO)
logger = colorlog.getLogger(__name__)
logger.setLevel(colorlog.INFO)

def get_md_text_abstract(rag_answer, source='Semantic Search', return_prompt_formatting=False):
    """Format a retrieval result (semantic-search dict or arxiv.Result) as a markdown card."""
    if 'Semantic Search' in source:
        title = rag_answer['document_metadata']['title'].replace('\n', '')
        #score = round(rag_answer['score'], 2)
        date = rag_answer['document_metadata']['_time']
        paper_abs = rag_answer['content']
        authors = rag_answer['document_metadata']['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join([author.name for author in rag_answer.authors])
        paper_link = rag_answer.links[0].href
        download_link = rag_answer.links[1].href
    else:
        raise ValueError(f"Unknown source: {source}")
    paper_title = f'### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'
    authors_formatted = f'*{authors}*' + ' \n\n'
    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'
    if return_prompt_formatting:
        doc = {
            'title': title,
            'text': paper_abs
        }
        return md_text_formatted, doc
    return md_text_formatted
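
# Illustrative usage sketch (not part of the app flow); the field values below are
# invented, but a real semantic-search hit is expected to carry the same keys:
# rag_answer = {
#     'document_id': '2101.00001',
#     'content': 'We propose ...',
#     'document_metadata': {'title': 'A Paper', 'authors': 'A. Author', '_time': '01 Jan 2021'},
# }
# md_text, doc = get_md_text_abstract(rag_answer, source='Semantic Search', return_prompt_formatting=True)
# md_text is the markdown card; doc is the {'title', 'text'} dict used by the prompt builders below.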

def remove_punctuation(text):
    # Strip all punctuation except apostrophes (to keep contractions intact)
    punct_str = string.punctuation
    punct_str = punct_str.replace("'", "")
    return text.translate(str.maketrans("", "", punct_str))

def remove_stopwords(text):
    text = ' '.join(word for word in text.split(' ') if word not in stop_words)
    return text

def search_cleaner(text):
    new_text = text.lower()
    new_text = remove_stopwords(new_text)
    new_text = remove_punctuation(new_text)
    return new_text
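
# Illustrative example (assuming NLTK's English stopword list):
# search_cleaner("What are the latest advances in diffusion models?")
# -> "latest advances diffusion models"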

# Restrict live arXiv searches to the main ML/AI categories
q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'

def get_arxiv_live_search(query, client, max_results=10):
    # Clean the user query and run a relevance-sorted arXiv API search
    clean_text = search_cleaner(query)
    search = arxiv.Search(
        query=clean_text + " AND " + q,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance
    )
    results = client.results(search)
    all_results = list(results)
    return all_results
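
# Illustrative usage sketch; arxiv.Client() is the standard client from the arxiv package:
# client = arxiv.Client()
# results = get_arxiv_live_search("retrieval augmented generation", client, max_results=5)
# for r in results:
#     print(get_md_text_abstract(r, source='Arxiv Search'))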

def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    # For doc prompt:
    # - {ID}: doc id (starting from 1)
    # - {T}: title
    # - {P}: text
    # use_shorter: None, "summary", or "extraction"
    text = doc['text']
    if use_shorter is not None:
        text = doc[use_shorter]
    return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id + 1))
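
# Illustrative example; the template string here is just one plausible doc_prompt:
# make_doc_prompt({'title': 'A Paper', 'text': 'Some abstract ...'}, 0,
#                 "Document [{ID}](Title: {T}): {P}\n")
# -> "Document [1](Title: A Paper): Some abstract ...\n"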

def get_shorter_text(item, docs, ndoc, key):
    doc_list = []
    for doc_id, doc in enumerate(docs):
        if key not in doc:
            if len(doc_list) == 0:
                # If there aren't any usable documents, at least provide one (using the full text)
                doc[key] = doc['text']
                doc_list.append(doc)
            logger.warning(f"No {key} found in document. Either this data does not contain {key} or the previous documents are not relevant. This is document {doc_id}. This question will only have {len(doc_list)} documents.")
            break
        if "irrelevant" in doc[key] or "Irrelevant" in doc[key]:
            continue
        doc_list.append(doc)
        if len(doc_list) >= ndoc:
            break
    return doc_list

def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    # For demo prompt
    # - {INST}: the instruction
    # - {D}: the documents
    # - {Q}: the question
    # - {A}: the answers
    # ndoc: number of documents to put in context
    # use_shorter: None, "summary", or "extraction"
    prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "")  # if there is no doc we also delete the empty line
        else:
            doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
            text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
            prompt = prompt.replace("{D}", text)
    if not test:
        answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip()  # remove any trailing space or newline
    return prompt
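
# Illustrative usage sketch; the templates below are assumptions, not the app's actual prompts:
# item = {'question': 'What is retrieval-augmented generation?',
#         'docs': [{'title': 'A Paper', 'text': 'Retrieval-augmented generation combines ...'}]}
# make_demo(item, prompt="{INST}\n\n{D}\nQuestion: {Q}\nAnswer:{A}", ndoc=1,
#           doc_prompt="Document [{ID}](Title: {T}): {P}\n",
#           instruction="Answer the question using the documents and cite them as [1], [2], ...",
#           test=True)
# With test=True the "{A}" placeholder is stripped so the model completes the answer.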

def load_llama_guard(model_id="meta-llama/Llama-Guard-3-1B"):
    from transformers import AutoTokenizer, AutoModelForCausalLM
    dtype = torch.bfloat16
    logger.info("loading llama_guard")
    llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="cuda")
    # Get the id of the "unsafe" token; this will later be used to extract its probability
    UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
    return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID

def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
    # tokenize=False returns the formatted prompt as a string
    prompt = tokenizer.apply_chat_template(chat, tokenize=False)
    # Skip the generation of whitespace so the next predicted token
    # will be either "safe" or "unsafe"
    prompt += "\n\n"
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # get logits
    )

    ######
    # Get generated text
    ######
    # Number of tokens that correspond to the input prompt
    input_length = inputs.input_ids.shape[1]
    # Ignore the tokens from the input to get the tokens generated by the model
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)

    ######
    # Get probability of the "unsafe" token
    ######
    # The first generated token is either "safe" or "unsafe";
    # use its logits to calculate the probabilities.
    first_token_logits = outputs.logits[0]
    first_token_probs = torch.softmax(first_token_logits, dim=-1)
    # From the probabilities of all tokens, extract the one for the "unsafe" token.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
    unsafe_probability = unsafe_probability.item()

    ######
    # Result
    ######
    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text
    }
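
# Illustrative usage sketch (requires a GPU and access to the gated Llama Guard weights):
# guard, guard_tokenizer, unsafe_id = load_llama_guard()
# chat = [{"role": "user", "content": "How do I bake a cake?"}]
# result = moderate(chat, guard, guard_tokenizer, unsafe_id)
# result["generated_text"] starts with "safe" or "unsafe"; result["unsafe_score"] is the
# probability that the first generated token is "unsafe".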

def get_max_memory():
    """Get the maximum memory available for the current GPU for loading models."""
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    max_memory = f'{free_in_GB-1}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    return max_memory
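
# Illustrative example: on a single GPU with roughly 22 GB free, this would return {0: '21GB'}.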

def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
    # Load a Hugging Face model and tokenizer
    # dtype: torch.float16 or torch.bfloat16
    # int8: whether to use int8 quantization
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    if int8:
        logger.warning("Using LLM.int8")
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        load_in_8bit=int8,
    )
    logger.info("Finished loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    tokenizer.padding_side = "left"
    return model, tokenizer
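
# Illustrative usage sketch; the model name is a placeholder:
# model, tokenizer = load_model("meta-llama/Llama-3.1-8B-Instruct", dtype=torch.bfloat16)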

def load_vllm(model_name_or_path, dtype=torch.bfloat16):
    from vllm import LLM, SamplingParams

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    start_time = time.time()
    model = LLM(
        model_name_or_path,
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_seq_len_to_capture=2048,
        max_model_len=8192,
    )
    sampling_params = SamplingParams(temperature=0.1, top_p=1.00, max_tokens=300)
    logger.info("Finished loading in %.2f sec." % (time.time() - start_time))
    # Load the tokenizer
    tokenizer = model.get_tokenizer()
    tokenizer.padding_side = "left"
    return model, tokenizer, sampling_params

class LLM:
    def __init__(self, model_name_or_path, use_vllm=True):
        self.use_vllm = use_vllm
        if use_vllm:
            self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
        else:
            self.chat_llm, self.tokenizer = load_model(model_name_or_path)
        self.prompt_exceed_max_length = 0
        self.fewer_than_50 = 0

    def generate(self, prompt, max_tokens=300, stop=None):
        if max_tokens <= 0:
            self.prompt_exceed_max_length += 1
            logger.warning("Prompt exceeds the max length; returning an empty string as the answer. If this happens too often, consider shortening the prompt.")
            return ""
        if max_tokens < 50:
            self.fewer_than_50 += 1
            logger.warning("Fewer than 50 tokens are left for generation. If this happens too often, consider shortening the prompt.")
        if self.use_vllm:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
            self.sampling_params.n = 1  # Number of output sequences to return for the given prompt
            self.sampling_params.stop_token_ids = [self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id]
            self.sampling_params.max_tokens = max_tokens
            output = self.chat_llm.generate(
                inputs,
                self.sampling_params,
                use_tqdm=True,
            )
            generation = output[0].outputs[0].text.strip()
        else:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.chat_llm.device)
            outputs = self.chat_llm.generate(
                **inputs,
                do_sample=True, temperature=0.1, top_p=1.0,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                eos_token_id=[self.chat_llm.config.eos_token_id]
            )
            generation = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True).strip()
        return generation
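
# Illustrative usage sketch; the model name is a placeholder:
# llm = LLM("meta-llama/Llama-3.1-8B-Instruct", use_vllm=True)
# answer = llm.generate("Summarise the key idea of retrieval-augmented generation.", max_tokens=200)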