# NOTE(review): `spaces` is kept as the very first import; ZeroGPU Spaces
# presumably need it imported before CUDA-related packages such as torch — confirm.
import spaces

import datetime
import string
import time

import arxiv
import colorlog
import nltk
import torch
from nltk.corpus import stopwords  # lazy corpus loader; data is read on first use

nltk.download('stopwords')
stop_words = stopwords.words('english')
# Set view of `stop_words` for O(1) membership tests in `remove_stopwords`.
_STOP_WORDS_SET = frozenset(stop_words)

# Translation table that deletes every ASCII punctuation character except the
# apostrophe (so contractions such as "don't" survive cleaning).
_PUNCT_TABLE = str.maketrans("", "", string.punctuation.replace("'", ""))

fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'

log_colors = {
    'DEBUG': 'white',
    'INFO': 'green',
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'purple',
}

colorlog.basicConfig(log_colors=log_colors, format=fmt_string, level=colorlog.INFO)
logger = colorlog.getLogger(__name__)
logger.setLevel(colorlog.INFO)


def get_md_text_abstract(rag_answer, source='Semantic Search', return_prompt_formatting=False):
    """Render one search hit as a Markdown paper card.

    Args:
        rag_answer: For a source containing 'Semantic Search', a dict with keys
            'document_metadata' (holding 'title', '_time', 'authors'), 'content'
            and 'document_id' (used as the arXiv id in the links). For a source
            containing 'Arxiv', an object exposing `title`, `updated` (datetime),
            `summary`, `authors` (each with `.name`) and `links` (abs link first,
            PDF link second) — presumably an `arxiv.Result`.
        source: Selects the input shape. Matched by substring; 'Semantic Search'
            is checked first.
        return_prompt_formatting: When True, also return a {'title', 'text'} dict
            suitable for `make_doc_prompt`.

    Returns:
        The Markdown string, or `(markdown, doc)` when `return_prompt_formatting`.

    Raises:
        ValueError: If `source` matches neither known source.
    """
    if 'Semantic Search' in source:
        metadata = rag_answer['document_metadata']
        title = metadata['title'].replace('\n', '')
        date = metadata['_time']
        paper_abs = rag_answer['content']
        authors = metadata['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join(author.name for author in rag_answer.authors)
        paper_link = rag_answer.links[0].href
        download_link = rag_answer.links[1].href
    else:
        raise ValueError(f"Unknown search source: {source!r}")

    paper_title = f'### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'
    authors_formatted = f'*{authors}*' + ' \n\n'

    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'
    if return_prompt_formatting:
        doc = {'title': title, 'text': paper_abs}
        return md_text_formatted, doc

    return md_text_formatted


def remove_punctuation(text):
    """Delete ASCII punctuation from `text`, keeping apostrophes."""
    return text.translate(_PUNCT_TABLE)


def remove_stopwords(text):
    """Drop NLTK English stopwords from `text`.

    Splits on single spaces only, so runs of spaces yield empty tokens, which
    are kept.
    """
    return ' '.join(word for word in text.split(' ') if word not in _STOP_WORDS_SET)


def search_cleaner(text):
    """Lower-case `text`, remove stopwords, then remove punctuation."""
    new_text = text.lower()
    new_text = remove_stopwords(new_text)
    new_text = remove_punctuation(new_text)
    return new_text


# arXiv category filter appended to every live search query.
q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'


def get_arxiv_live_search(query, client, max_results=10):
    """Run a relevance-sorted arXiv search restricted to the categories in `q`.

    Args:
        query: Free-text user query; cleaned with `search_cleaner` first.
        client: An `arxiv.Client` (anything with a `.results(search)` method).
        max_results: Maximum number of results requested.

    Returns:
        A list of the results yielded by `client.results`.
    """
    clean_text = search_cleaner(query)
    search = arxiv.Search(
        query=clean_text + " AND " + q,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    return list(client.results(search))


def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    """Fill a document template.

    Placeholders in `doc_prompt`:
        {ID}: 1-based document id (`doc_id + 1`)
        {T}: doc['title']
        {P}: doc['text'], or doc[use_shorter] when `use_shorter` is given
            (None, "summary", or "extraction")
    """
    text = doc['text'] if use_shorter is None else doc[use_shorter]
    return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id + 1))


def get_shorter_text(item, docs, ndoc, key):
    """Select up to `ndoc` docs whose `key` field is present and not marked irrelevant.

    `item` is unused and is kept only for call compatibility.

    If a doc lacks `key` before anything has been selected, that doc's `key` is
    set to its full 'text' (mutating the dict in place), and it alone is
    returned. If a doc lacks `key` after some docs were selected, selection
    stops there. Docs whose `key` text contains "irrelevant"/"Irrelevant" are
    skipped.
    """
    doc_list = []
    for doc_id, doc in enumerate(docs):
        if key not in doc:
            if len(doc_list) == 0:
                # If there aren't any documents, at least provide one (using full text)
                doc[key] = doc['text']
                doc_list.append(doc)
            logger.warning(
                f"No {key} found in document. It could be this data do not contain {key} "
                f"or previous documents are not relevant. This is document {doc_id}. "
                f"This question will only have {len(doc_list)} documents."
            )
            break
        if "irrelevant" in doc[key] or "Irrelevant" in doc[key]:
            continue
        doc_list.append(doc)
        if len(doc_list) >= ndoc:
            break
    return doc_list


def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    """Build a (demo or test) prompt from a template.

    Placeholders in `prompt`:
        {INST}: `instruction` (empty string when None)
        {D}: the rendered documents from item["docs"]
        {Q}: item['question']
        {A}: the answer(s); filled only when `test` is False
    ndoc: number of documents to put in context (None keeps all; 0 removes
        "{D}\\n" entirely).
    use_shorter: None, "summary", or "extraction". When set, documents are
        chosen with `get_shorter_text` and rendered from that field.
    doc_prompt: document template passed to `make_doc_prompt`; required
        whenever documents are rendered.
    """
    prompt = prompt.replace("{INST}", instruction if instruction is not None else "").replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "")  # if there is no doc we also delete the empty line
        else:
            if use_shorter is not None:
                doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter)
            else:
                doc_list = item["docs"][:ndoc]
            text = "".join(
                make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter)
                for doc_id, doc in enumerate(doc_list)
            )
            prompt = prompt.replace("{D}", text)

    if not test:
        answer = ("\n" + "\n".join(item["answer"])) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip()  # remove any space or \n

    return prompt


def load_llama_guard(model_id="meta-llama/Llama-Guard-3-1B"):
    """Load a Llama Guard model in bfloat16 on CUDA.

    Returns:
        (model, tokenizer, UNSAFE_TOKEN_ID), where UNSAFE_TOKEN_ID is the
        tokenizer id of the "unsafe" token, used by `moderate`.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype = torch.bfloat16
    logger.info("loading llama_guard")
    llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda')

    # Get the id of the "unsafe" token, this will later be used to extract its probability
    UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")

    return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID


@spaces.GPU(duration=120)
def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
    """Classify `chat` with Llama Guard (greedy decoding).

    Args:
        chat: Messages accepted by `tokenizer.apply_chat_template`.
        model, tokenizer, UNSAFE_TOKEN_ID: As returned by `load_llama_guard`.

    Returns:
        {"unsafe_score": probability of "unsafe" at the first generated
         position, "generated_text": decoded model output}
    """
    prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
    # Skip the generation of whitespace.
    # Now the next predicted token will be either "safe" or "unsafe"
    prompt += "\n\n"

    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # get logits
        do_sample=False,
        temperature=None,
        top_p=None,
    )

    # --- Generated text ---
    # Ignore the tokens from the input prompt to keep only what the model generated.
    input_length = inputs.input_ids.shape[1]
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    logger.info(generated_token_ids)
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
    logger.info(generated_text)

    # --- Probability of the "unsafe" token ---
    # `outputs.logits` holds one tensor per generation step; take the first step.
    # Upcast before softmax so low-precision (e.g. bfloat16) logits do not
    # quantize the resulting probability.
    first_token_logits = outputs.logits[0].float()
    first_token_probs = torch.softmax(first_token_logits, dim=-1)

    # Row 0 is the single prompt in this batch.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID].item()

    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text,
    }


def get_max_memory():
    """Build a `max_memory` mapping for `from_pretrained(device_map='auto')`.

    Uses the free memory of the *current* CUDA device, minus 1 GB, as the
    budget for every visible GPU.
    NOTE(review): assumes all GPUs have similar free memory; if less than 1 GB is
    free the budget becomes 0 or negative — confirm this is acceptable.
    """
    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
    per_gpu = f'{free_in_GB - 1}GB'
    n_gpus = torch.cuda.device_count()
    return {i: per_gpu for i in range(n_gpus)}


def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
    """Load a Hugging Face causal LM and its tokenizer.

    Args:
        model_name_or_path: Hub id or local path.
        dtype: torch.float16 or torch.bfloat16.
        int8: whether to use LLM.int8 quantization (`load_in_8bit`).

    Returns:
        (model, tokenizer); the tokenizer pads on the left.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    if int8:
        logger.warning("Use LLM.int8")
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        load_in_8bit=int8,
    )
    logger.info("Finish loading in %.2f sec.", time.time() - start_time)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    tokenizer.padding_side = "left"

    return model, tokenizer


def load_vllm(model_name_or_path, dtype=torch.bfloat16):
    """Load a model with vLLM.

    Returns:
        (vllm.LLM instance, its tokenizer (left padding), default SamplingParams
        with temperature=0.1, top_p=0.95, max_tokens=300).
    """
    from vllm import LLM, SamplingParams

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    start_time = time.time()
    model = LLM(
        model_name_or_path,
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_seq_len_to_capture=2048,
        max_model_len=8192,
    )
    sampling_params = SamplingParams(temperature=0.1, top_p=0.95, max_tokens=300)
    logger.info("Finish loading in %.2f sec.", time.time() - start_time)

    # Load the tokenizer
    tokenizer = model.get_tokenizer()
    tokenizer.padding_side = "left"

    return model, tokenizer, sampling_params


def _as_token_id_list(token_ids):
    """Normalize an `eos_token_id`-style config value to a flat list of ids.

    Hugging Face configs store `eos_token_id` either as one int or as a list of
    ints; None yields an empty list.
    """
    if token_ids is None:
        return []
    if isinstance(token_ids, int):
        return [token_ids]
    return list(token_ids)


class LLM:
    """Chat-model wrapper over either vLLM (default) or Hugging Face transformers."""

    def __init__(self, model_name_or_path, use_vllm=True):
        self.use_vllm = use_vllm
        if use_vllm:
            self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
        else:
            self.chat_llm, self.tokenizer = load_model(model_name_or_path)

        # Counters of prompts that left no / little room for generation.
        self.prompt_exceed_max_length = 0
        self.fewer_than_50 = 0

    def generate(self, prompt, max_tokens=300, stop=None):
        """Generate a reply to `prompt` sent as a single user chat turn.

        Args:
            prompt: User message text.
            max_tokens: Maximum new tokens. If <= 0, nothing is generated and
                "" is returned (and `prompt_exceed_max_length` is incremented).
            stop: Accepted for interface compatibility; currently not used.

        Returns:
            The stripped generated text.
        """
        if max_tokens <= 0:
            self.prompt_exceed_max_length += 1
            logger.warning("Prompt exceeds max length and return an empty string as answer. If this happens too many times, it is suggested to make the prompt shorter")
            return ""
        if max_tokens < 50:
            self.fewer_than_50 += 1
            logger.warning("The model can at most generate < 50 tokens. If this happens too many times, it is suggested to make the prompt shorter")

        messages = [{"role": "user", "content": prompt}]
        if self.use_vllm:
            inputs = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            eos_ids = self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id
            # Note: mutates the shared SamplingParams stored on this instance.
            self.sampling_params.n = 1  # Number of output sequences to return for the given prompt
            # eos_token_id may already be a list (e.g. several end-of-turn ids); avoid nesting it.
            self.sampling_params.stop_token_ids = _as_token_id_list(eos_ids)
            self.sampling_params.max_tokens = max_tokens
            output = self.chat_llm.generate(
                inputs,
                self.sampling_params,
                use_tqdm=True,
            )
            generation = output[0].outputs[0].text.strip()
        else:
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.chat_llm.device)
            outputs = self.chat_llm.generate(
                **inputs,
                do_sample=True,
                temperature=0.1,
                top_p=0.95,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                # Fall back to the generation config's default when the model config has no eos id.
                eos_token_id=_as_token_id_list(self.chat_llm.config.eos_token_id) or None,
            )
            prompt_length = inputs['input_ids'].size(1)
            generation = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
        return generation