JAM / utils.py
OrangeEye's picture
update Trust-Align
2023373
raw
history blame
11.7 kB
import datetime  # NOTE(review): not referenced anywhere in this file's visible code — confirm before removing.
import string
import nltk
# Fetch the NLTK stopword corpus; this runs on every import of this module.
nltk.download('stopwords')
from nltk.corpus import stopwords
# English stopword list consulted by remove_stopwords() below.
stop_words = stopwords.words('english')
import time
import arxiv
import colorlog
import torch
# Log line layout; %(log_color)s is filled in by colorlog from log_colors.
fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'
log_colors = dict(
    DEBUG='white',
    INFO='green',
    WARNING='yellow',
    ERROR='red',
    CRITICAL='purple',
)
# Configure the root handler once at import time, then obtain this module's logger.
colorlog.basicConfig(
    format=fmt_string,
    log_colors=log_colors,
    level=colorlog.INFO,
)
logger = colorlog.getLogger(__name__)
logger.setLevel(colorlog.INFO)
def get_md_text_abstract(rag_answer, source = 'Semantic Search', return_prompt_formatting = False):
    """Render one search hit as a Markdown card.

    Args:
        rag_answer: Either a semantic-search hit (a dict with
            ``document_metadata`` -> ``title`` / ``_time`` / ``authors``, plus
            ``content`` and ``document_id``) or an arXiv-result-like object
            exposing ``title``, ``updated``, ``summary``, ``authors`` and
            ``links`` (and optionally ``entry_id`` / ``pdf_url``).
        source: Which format *rag_answer* is in. Any string containing
            ``'Semantic Search'`` selects the dict format. Otherwise, any string
            containing ``'Arxiv'`` selects the arXiv-result format.
        return_prompt_formatting: When true, also return a
            ``{'title': ..., 'text': ...}`` dict (the fields read by
            ``make_doc_prompt``).

    Returns:
        The Markdown string, or ``(markdown, doc)`` when
        *return_prompt_formatting* is true.

    Raises:
        ValueError: if *source* matches neither supported format.
    """
    if 'Semantic Search' in source:
        metadata = rag_answer['document_metadata']
        title = metadata['title'].replace('\n', '')
        date = metadata['_time']
        paper_abs = rag_answer['content']
        authors = metadata['authors'].replace('\n', '')
        doc_id = rag_answer['document_id']
        paper_link = f'https://arxiv.org/abs/{doc_id}'
        download_link = f'https://arxiv.org/pdf/{doc_id}'
    elif 'Arxiv' in source:
        title = rag_answer.title
        date = rag_answer.updated.strftime('%d %b %Y')
        paper_abs = rag_answer.summary.replace('\n', ' ') + '\n'
        authors = ', '.join(author.name for author in rag_answer.authors)
        # The position of entries in `links` is not fixed (a DOI link can come
        # before the abstract link), so prefer the dedicated abstract/PDF URLs
        # when the result object provides them.
        paper_link = getattr(rag_answer, 'entry_id', None) or rag_answer.links[0].href
        download_link = getattr(rag_answer, 'pdf_url', None) or rag_answer.links[1].href
    else:
        raise ValueError(
            f"Unsupported source {source!r}; expected one containing "
            f"'Semantic Search' or 'Arxiv'"
        )

    paper_title = f'### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'
    authors_formatted = f'*{authors}*' + ' \n\n'
    md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n' + '\n'
    if return_prompt_formatting:
        doc = {
            'title': title,
            'text': paper_abs,
        }
        return md_text_formatted, doc
    return md_text_formatted
def remove_punctuation(text):
    """Delete every ``string.punctuation`` character from *text* except the apostrophe.

    Keeping ``'`` leaves contractions such as "don't" intact.
    """
    to_drop = ''.join(ch for ch in string.punctuation if ch != "'")
    table = str.maketrans('', '', to_drop)
    return text.translate(table)
def remove_stopwords(text, stopword_list=None):
    """Drop stopwords from *text*.

    The text is split on single spaces, so a run of spaces yields empty tokens,
    which are kept. Surviving tokens are re-joined with single spaces.
    Matching is exact and case-sensitive; ``search_cleaner`` lowercases first.

    Args:
        text: Input string.
        stopword_list: Iterable of words to remove. Defaults to the
            module-level NLTK English ``stop_words`` list.

    Returns:
        *text* without the stopword tokens.
    """
    vocab = stop_words if stopword_list is None else stopword_list
    # A set gives O(1) membership per token instead of scanning the list each time.
    vocab = set(vocab)
    return ' '.join(word for word in text.split(' ') if word not in vocab)
def search_cleaner(text):
    """Normalise a free-text query; used by ``get_arxiv_live_search``.

    The steps run in this order: lowercase, drop stopwords, then strip
    punctuation (apostrophes are kept). Stopwords are removed before
    punctuation, so a stopword with punctuation attached (e.g. followed by a
    comma) is not removed.
    """
    lowered = text.lower()
    return remove_punctuation(remove_stopwords(lowered))
# arXiv category filter; get_arxiv_live_search() ANDs it onto every cleaned query
# so live results are restricted to these cs.* categories.
q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'
def get_arxiv_live_search(query, client, max_results = 10):
    """Query arXiv live for *query*, restricted to the categories in ``q``.

    The query is cleaned with ``search_cleaner`` and sorted by relevance.
    Returns a list of everything ``client.results`` yields. These are
    presumably ``arxiv.Result`` objects, as consumed by the 'Arxiv' branch of
    ``get_md_text_abstract``; confirm against the arxiv client in use.
    """
    search_query = f"{search_cleaner(query)} AND {q}"
    request = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )
    return list(client.results(request))
def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    """Fill a single-document prompt template.

    Placeholders in *doc_prompt*:
        {ID}: 1-based document number (``doc_id + 1``)
        {T}:  ``doc['title']``
        {P}:  the passage: ``doc['text']``, or ``doc[use_shorter]`` when
              *use_shorter* is given (e.g. "summary" or "extraction")

    All placeholders are substituted in one pass. Placeholder-like text that
    happens to appear inside the title or passage is therefore inserted
    literally instead of being expanded again.

    Returns:
        The filled template string.
    """
    import re

    # Only read the field that is actually used, so docs that carry just the
    # shortened field do not raise KeyError on 'text'.
    text = doc['text'] if use_shorter is None else doc[use_shorter]
    values = {'T': doc['title'], 'P': text, 'ID': str(doc_id + 1)}
    # A callable replacement inserts values verbatim (no backslash-escape handling).
    return re.sub(r'\{(T|P|ID)\}', lambda m: values[m.group(1)], doc_prompt)
def get_shorter_text(item, docs, ndoc, key):
    """Select up to *ndoc* documents whose shortened text ``doc[key]`` is relevant.

    Args:
        item: Unused; kept for backward compatibility with existing callers.
        docs: Sequence of document dicts.
        ndoc: Maximum number of documents to return; ``None`` means no limit.
        key: Field holding the shortened text (e.g. "summary" or "extraction").

    Returns:
        A list of documents in input order. Documents whose ``doc[key]``
        contains "irrelevant"/"Irrelevant" are skipped. Scanning stops at the
        first document that lacks *key*, and a warning is logged. If nothing
        has been selected by then, a copy of that document with *key* filled
        from its 'text' is returned instead, so at least one document is
        provided. The caller's document is not modified.
    """
    doc_list = []
    for doc_id, doc in enumerate(docs):
        if key not in doc:
            if not doc_list:
                # If there aren't any documents, at least provide one (using full text).
                # Copy rather than mutate, so repeated calls on the same data behave the same.
                doc_list.append({**doc, key: doc['text']})
            logger.warning(f"No {key} found in document. It could be this data do not contain {key} or previous documents are not relevant. This is document {doc_id}. This question will only have {len(doc_list)} documents.")
            break
        if "irrelevant" in doc[key] or "Irrelevant" in doc[key]:
            continue
        doc_list.append(doc)
        if ndoc is not None and len(doc_list) >= ndoc:
            break
    return doc_list
def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    """Fill a demonstration (or test-question) prompt template for one example.

    Placeholders in *prompt*:
        {INST}: *instruction*; ``None`` is treated as an empty instruction.
        {D}:    the documents, each rendered via ``make_doc_prompt(doc_prompt)``.
        {Q}:    ``item['question']``.
        {A}:    removed; for demos (``test=False``) the answer is appended instead.

    Args:
        item: Example dict with 'question', plus 'docs' when the template has
            {D} and 'answer' when *test* is false.
        prompt: Template string.
        ndoc: Number of documents to put in context. 0 removes "{D}\\n"
            entirely; ``None`` keeps every entry of ``item['docs']`` when
            *use_shorter* is None.
        doc_prompt: Per-document template passed to ``make_doc_prompt``.
        instruction: Text substituted for {INST}.
        use_shorter: None, "summary", or "extraction". When set, documents are
            filtered with ``get_shorter_text`` and that field is used as passage.
        test: If true, no answer is appended.

    Returns:
        The prompt with trailing whitespace stripped before the (optional) answer.
        A list answer is appended as "\\n" followed by its lines joined with "\\n".
    """
    # str.replace rejects None, so the default instruction=None used to crash here.
    prompt = prompt.replace("{INST}", "" if instruction is None else instruction)
    prompt = prompt.replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "")  # if there is no doc we also delete the empty line
        else:
            if use_shorter is not None:
                doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter)
            else:
                doc_list = item["docs"][:ndoc]
            text = "".join(
                make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter)
                for doc_id, doc in enumerate(doc_list)
            )
            prompt = prompt.replace("{D}", text)
    if not test:
        answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip()  # remove any space or \n
    return prompt
def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
    """Load a Llama Guard model and its tokenizer, and look up the "unsafe" token id.

    The model is loaded in bfloat16 with ``device_map="cuda"``.

    Returns:
        ``(model, tokenizer, unsafe_token_id)``. The id is later used by
        ``moderate`` to read the probability of the "unsafe" token.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info("loading llama_guard")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )
    unsafe_token_id = tokenizer.convert_tokens_to_ids("unsafe")
    return model, tokenizer, unsafe_token_id
def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
    """Classify *chat* with a Llama Guard model.

    Returns:
        dict with
          "unsafe_score": softmax probability of UNSAFE_TOKEN_ID at the first
              generated position;
          "generated_text": the decoded generation (at most 50 new tokens),
              special tokens skipped.
    """
    # Render the conversation to text and append "\n\n" ourselves, so the
    # model skips generating that whitespace and its first predicted token is
    # "safe" or "unsafe".
    rendered = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False) + "\n\n"
    # NOTE(review): if the chat template already emits a BOS token, tokenizing
    # with default special tokens may add a second one — confirm.
    encoded = tokenizer([rendered], return_tensors="pt").to(model.device)
    result = model.generate(
        **encoded,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # keep per-step logits for the score below
    )

    # Decode only the continuation: drop the prompt tokens from each sequence.
    prompt_len = encoded.input_ids.shape[1]
    new_tokens = result.sequences[:, prompt_len:].cpu()
    text_out = tokenizer.decode(new_tokens[0], skip_special_tokens=True)

    # Logits of the first generated step, batch row 0 -> probability of "unsafe".
    first_step_probs = torch.softmax(result.logits[0], dim=-1)
    unsafe_score = first_step_probs[0, UNSAFE_TOKEN_ID].item()

    return {
        "unsafe_score": unsafe_score,
        "generated_text": text_out,
    }
def get_max_memory():
    """Build a per-GPU ``max_memory`` budget for ``from_pretrained(device_map=...)``.

    Each visible CUDA device is queried separately. Free memory can differ
    between GPUs, and the previous version applied the current device's figure
    to all of them. The budget is the free memory truncated to whole GiB,
    minus 1 GiB of headroom, never below 0.

    Returns:
        dict mapping device index -> size string such as ``'23GB'``.
    """
    budget = {}
    for device in range(torch.cuda.device_count()):
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        free_in_gb = int(free_bytes / 1024**3)
        # Clamp so a nearly-full GPU yields '0GB' rather than an invalid '-1GB'.
        budget[device] = f'{max(free_in_gb - 1, 0)}GB'
    return budget
def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
    """Load a Hugging Face causal LM and its (fast) tokenizer.

    The model is placed with ``device_map='auto'``, bounded per GPU by
    ``get_max_memory()``.

    Args:
        model_name_or_path: Hub id or local path.
        dtype: Weight dtype, e.g. torch.float16 or torch.bfloat16.
        int8: If true, load with bitsandbytes LLM.int8 quantization.

    Returns:
        ``(model, tokenizer)``; the tokenizer is set to pad on the left.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    quant_kwargs = {}
    if int8:
        from transformers import BitsAndBytesConfig

        logger.warning("Use LLM.int8")
        # Passing load_in_8bit straight to from_pretrained is deprecated in
        # favour of a quantization config; only send it when requested.
        quant_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        **quant_kwargs,
    )
    logger.info("Finish loading in %.2f sec." % (time.time() - start_time))

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    tokenizer.padding_side = "left"
    return model, tokenizer
def load_vllm(model_name_or_path, dtype=torch.bfloat16):
    """Load *model_name_or_path* with vLLM.

    The engine uses 90% GPU memory utilisation and an 8192-token max model
    length.

    Returns:
        ``(engine, tokenizer, sampling_params)``. The default sampling params
        are temperature 0.1, top_p 1.0 and at most 300 new tokens. The
        tokenizer is set to pad on the left.
    """
    # vLLM's LLM shadows this module's LLM class inside this function only.
    from vllm import LLM, SamplingParams

    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    t0 = time.time()
    engine = LLM(
        model_name_or_path,
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_seq_len_to_capture=2048,
        max_model_len=8192,
    )
    defaults = SamplingParams(temperature=0.1, top_p=1.00, max_tokens=300)
    logger.info("Finish loading in %.2f sec." % (time.time() - t0))

    tok = engine.get_tokenizer()
    tok.padding_side = "left"
    return engine, tok, defaults
class LLM:
    """Single-turn chat generator backed by vLLM (default) or Hugging Face transformers.

    Attributes:
        use_vllm: Which backend is used.
        chat_llm: The loaded vLLM engine or HF model.
        tokenizer: Its tokenizer (left padding).
        sampling_params: vLLM SamplingParams, present only when use_vllm. It is
            mutated in place by generate().
        prompt_exceed_max_length: Count of generate() calls refused because
            max_tokens <= 0.
        fewer_than_50: Count of generate() calls with 0 < max_tokens < 50.
    """

    def __init__(self, model_name_or_path, use_vllm=True):
        self.use_vllm = use_vllm
        if use_vllm:
            self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
        else:
            self.chat_llm, self.tokenizer = load_model(model_name_or_path)
        self.prompt_exceed_max_length = 0
        self.fewer_than_50 = 0

    @staticmethod
    def _as_token_id_list(eos_token_id):
        """Normalise a config ``eos_token_id`` (int, list/tuple of ints, or None) to a flat list.

        Some model configs already store a list of eos ids, so wrapping the
        value in ``[...]`` unconditionally would produce a nested list.
        """
        if eos_token_id is None:
            return []
        if isinstance(eos_token_id, (list, tuple)):
            return list(eos_token_id)
        return [eos_token_id]

    def generate(self, prompt, max_tokens=300, stop=None):
        """Generate a reply to *prompt*, sent as a single user chat turn.

        Args:
            prompt: User message text.
            max_tokens: Maximum number of new tokens. If <= 0, "" is returned
                and prompt_exceed_max_length is incremented. If < 50,
                fewer_than_50 is incremented and a warning is logged.
            stop: Accepted for interface compatibility; not used by either backend here.

        Returns:
            The generated text with surrounding whitespace stripped.
        """
        if max_tokens <= 0:
            self.prompt_exceed_max_length += 1
            logger.warning("Prompt exceeds max length and return an empty string as answer. If this happens too many times, it is suggested to make the prompt shorter")
            return ""
        if max_tokens < 50:
            self.fewer_than_50 += 1
            logger.warning("The model can at most generate < 50 tokens. If this happens too many times, it is suggested to make the prompt shorter")

        messages = [{"role": "user", "content": prompt}]
        if self.use_vllm:
            inputs = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            self.sampling_params.n = 1  # Number of output sequences to return for the given prompt
            eos = self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id
            self.sampling_params.stop_token_ids = self._as_token_id_list(eos)
            self.sampling_params.max_tokens = max_tokens
            output = self.chat_llm.generate(
                inputs,
                self.sampling_params,
                use_tqdm=True,
            )
            generation = output[0].outputs[0].text.strip()
        else:
            inputs = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.chat_llm.device)
            # An empty list would mean "no eos"; pass None so the model's
            # generation config decides instead.
            eos_ids = self._as_token_id_list(self.chat_llm.config.eos_token_id) or None
            outputs = self.chat_llm.generate(
                **inputs,
                do_sample=True, temperature=0.1, top_p=1.0,
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                eos_token_id=eos_ids,
            )
            # Decode only the newly generated tokens (skip the prompt prefix).
            prompt_len = inputs['input_ids'].size(1)
            generation = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
        return generation