import spaces
import datetime
import string

import nltk

nltk.download('stopwords')
from nltk.corpus import stopwords

stop_words = stopwords.words('english')
import time

import arxiv
import colorlog
import torch

fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'
log_colors = {
        'DEBUG': 'white',
        'INFO': 'green',
        'WARNING': 'yellow',
        'ERROR': 'red',
        'CRITICAL': 'purple'
        }
colorlog.basicConfig(log_colors=log_colors, format=fmt_string, level=colorlog.INFO)
logger = colorlog.getLogger(__name__)
logger.setLevel(colorlog.INFO)



def get_md_text_abstract(rag_answer, source = ['Arxiv Search', 'Semantic Search'][1], return_prompt_formatting = False):
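    """Format one retrieved paper as a markdown block (date | title | download link, authors, abstract).

    `rag_answer` is either a semantic-search hit (a dict) or an `arxiv.Result`, selected via `source`.
    When `return_prompt_formatting` is True, also return a {'title', 'text'} dict for prompt construction.
    """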
    if 'Semantic Search' in source:
      title = rag_answer['document_metadata']['title'].replace('\n','')
      #score = round(rag_answer['score'], 2)
      date = rag_answer['document_metadata']['_time']  
      paper_abs = rag_answer['content']
      authors = rag_answer['document_metadata']['authors'].replace('\n','')
      doc_id = rag_answer['document_id']
      paper_link = f'''https://arxiv.org/abs/{doc_id}'''
      download_link = f'''https://arxiv.org/pdf/{doc_id}'''

    elif 'Arxiv' in source:
      title = rag_answer.title
      date = rag_answer.updated.strftime('%d %b %Y')
      paper_abs = rag_answer.summary.replace('\n',' ') + '\n'
      authors = ', '.join([author.name for author in rag_answer.authors])
      paper_link = rag_answer.links[0].href
      download_link = rag_answer.links[1].href
      
    else:
      raise ValueError(f"Unknown source: {source}")

    paper_title = f'''### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'''
    authors_formatted = f'*{authors}*' + ' \n\n'

    md_text_formatted = paper_title + authors_formatted + paper_abs +  '\n---------------\n'+ '\n'
    if return_prompt_formatting:
        doc = {
            'title': title,
            'text': paper_abs
        }
        return md_text_formatted, doc

    return md_text_formatted

def remove_punctuation(text):
    # Strip all punctuation except apostrophes (so contractions like "don't" survive).
    punct_str = string.punctuation.replace("'", "")
    return text.translate(str.maketrans("", "", punct_str))

def remove_stopwords(text):
    text = ' '.join(word for word in text.split(' ') if word not in stop_words)
    return text

def search_cleaner(text):
    new_text = text.lower()
    new_text = remove_stopwords(new_text)
    new_text = remove_punctuation(new_text)
    return new_text


q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'
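# `q` is appended to every live arXiv query (see get_arxiv_live_search) to restrict results to ML/AI-related categories.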


def get_arxiv_live_search(query, client, max_results = 10):
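  """Run a live arXiv API search for the cleaned `query`, restricted to the categories in `q`, sorted by relevance."""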
  clean_text = search_cleaner(query)
  search = arxiv.Search(
    query = clean_text + " AND "+q,
    max_results = max_results,
    sort_by = arxiv.SortCriterion.Relevance
  )
  results = client.results(search)
  all_results = list(results)
  return all_results


def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
    # For doc prompt:
    # - {ID}: doc id (starting from 1)
    # - {T}: title
    # - {P}: text
    # use_shorter: None, "summary", or "extraction"
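    # e.g. doc_prompt = "Document [{ID}](Title: {T}): {P}\n"  (illustrative template; any string using these placeholders works)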

    text = doc['text']
    if use_shorter is not None:
        text = doc[use_shorter]
    return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id+1))


def get_shorter_text(item, docs, ndoc, key):
    # Collect up to `ndoc` documents whose `key` field (e.g. "summary" or "extraction") is usable,
    # skipping documents marked irrelevant.
    doc_list = []
    for doc_id, doc in enumerate(docs):
        if key not in doc:
            if len(doc_list) == 0:
                # If no document has been collected yet, at least provide one (falling back to the full text)
                doc[key] = doc['text']
                doc_list.append(doc)
            logger.warning(f"No {key} found in document {doc_id}. Either this data does not contain {key} or the previous documents are not relevant. This question will only use {len(doc_list)} document(s).")
            break
        if "irrelevant" in doc[key] or "Irrelevant" in doc[key]:
            continue
        doc_list.append(doc)
        if len(doc_list) >= ndoc:
            break
    return doc_list


def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
    # For demo prompt
    # - {INST}: the instruction
    # - {D}: the documents
    # - {Q}: the question
    # - {A}: the answers
    # ndoc: number of documents to put in context
    # use_shorter: None, "summary", or "extraction"
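    # e.g. prompt = "{INST}\n\nQuestion: {Q}\n\n{D}\nAnswer:{A}"  (illustrative template; callers supply the actual format)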

    prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "") # if there is no doc we also delete the empty line
        else:
            doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
            text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
            prompt = prompt.replace("{D}", text)

    if not test:
        answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip() # remove any space or \n

    return prompt


def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
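    """Load the Llama Guard model and tokenizer on GPU, plus the token id of "unsafe" (used later to score moderation)."""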
    from transformers import AutoTokenizer, AutoModelForCausalLM
    dtype = torch.bfloat16

    logger.info("loading llama_guard")
    llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
    llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda')

    # Get the id of the "unsafe" token; it is later used to extract that token's probability
    UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
    
    return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID


# @spaces.GPU(duration=120)
def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
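    """Run Llama Guard on `chat`; return the generated verdict text and the probability assigned to the "unsafe" token."""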
    
    prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
    # Skip the generation of whitespace.
    # Now the next predicted token will be either "safe" or "unsafe"
    prompt += "\n\n"
    
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
        output_logits=True,  # get logits
        do_sample=False,
        temperature=None,
        top_p=None
    )
    ######
    # Get generated text
    ######
    # logger.info(outputs)
    # Number of tokens that correspond to the input prompt
    input_length = inputs.input_ids.shape[1]
    # Ignore the tokens from the input to get the tokens generated by the model
    generated_token_ids = outputs.sequences[:, input_length:].cpu()
    logger.info(generated_token_ids)
    generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
    logger.info(generated_text)
    ######
    # Get Probability of "unsafe" token
    ######
   
    # The first generated token is either "safe" or "unsafe";
    # use its logits to calculate the probabilities.
    first_token_logits = outputs.logits[0]
    first_token_probs = torch.softmax(first_token_logits, dim=-1)
    
    # From the probabilities of all tokens, extract the one for the "unsafe" token.
    unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
    unsafe_probability = unsafe_probability.item()

    ########
    # Result
    ########
    return {
        "unsafe_score": unsafe_probability,
        "generated_text": generated_text
    }
    
    

def get_max_memory():
    """Get the maximum memory available for the current GPU for loading models."""
    free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
    max_memory = f'{free_in_GB-1}GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    return max_memory


def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
    # Load a huggingface model and tokenizer
    # dtype: torch.float16 or torch.bfloat16
    # int8: whether to use int8 (LLM.int8) quantization

    from transformers import AutoModelForCausalLM, AutoTokenizer
    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    if int8:
        logger.warning("Using LLM.int8")
    start_time = time.time()
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map='auto',
        torch_dtype=dtype,
        max_memory=get_max_memory(),
        load_in_8bit=int8,
    )
    logger.info("Finish loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    tokenizer.padding_side = "left"

    return model, tokenizer


def load_vllm(model_name_or_path, dtype=torch.bfloat16):
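    """Load `model_name_or_path` with vLLM and return the engine, its tokenizer, and default sampling parameters."""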
    from vllm import LLM, SamplingParams
    logger.info(f"Loading {model_name_or_path} in {dtype}...")
    start_time = time.time()
    model = LLM(
        model_name_or_path, 
        dtype=dtype,
        gpu_memory_utilization=0.9,
        max_seq_len_to_capture=2048,
        max_model_len=8192,
    )
    sampling_params = SamplingParams(temperature=0.1, top_p=0.95, max_tokens=300)
    logger.info("Finish loading in %.2f sec." % (time.time() - start_time))

    # Load the tokenizer
    tokenizer = model.get_tokenizer()
    
    tokenizer.padding_side = "left"
    
    return model, tokenizer, sampling_params



class LLM:
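    """Thin wrapper around either a vLLM engine or a HuggingFace causal LM for chat-style generation."""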

    def __init__(self, model_name_or_path, use_vllm=True):
        self.use_vllm = use_vllm
        if use_vllm:
            self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
        else:
            self.chat_llm, self.tokenizer = load_model(model_name_or_path)
        
        self.prompt_exceed_max_length = 0
        self.fewer_than_50 = 0
        
    def generate(self, prompt, max_tokens=300, stop=None):
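        """Generate a single chat-style completion for `prompt`, via vLLM or HF `generate` depending on `use_vllm`."""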
        if max_tokens <= 0:
            self.prompt_exceed_max_length += 1
            logger.warning("Prompt exceeds max length and return an empty string as answer. If this happens too many times, it is suggested to make the prompt shorter")
            return ""
        if max_tokens < 50:
            self.fewer_than_50 += 1
            logger.warning("The model can at most generate < 50 tokens. If this happens too many times, it is suggested to make the prompt shorter")
        
        if self.use_vllm:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
            self.sampling_params.n = 1  # Number of output sequences to return for the given prompt
            self.sampling_params.stop_token_ids = [self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id]
            self.sampling_params.max_tokens = max_tokens
            output = self.chat_llm.generate(
                inputs,
                self.sampling_params,
                use_tqdm=True,
            )
            generation = output[0].outputs[0].text.strip()

        else:
            inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.chat_llm.device)
            outputs = self.chat_llm.generate(
                **inputs,
                do_sample=True, temperature=0.1, top_p=0.95, 
                max_new_tokens=max_tokens,
                num_return_sequences=1,
                eos_token_id=[self.chat_llm.config.eos_token_id]
            )
            generation = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True).strip()
            
        return generation