OrangeEye committed
Commit a5d6d73 · 0 Parent(s)

update Trust-Align

.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ colbert/indexes/arxiv_colbert/collection.json filter=lfs diff=lfs merge=lfs -text
+ colbert/indexes/arxiv_colbert/docid_metadata_map.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Trust-Align
+ emoji: 🔥
+ colorFrom: blue
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.37.2
+ app_file: app.py
+ pinned: false
+ short_description: Measuring and Enhancing Trustworthiness of LLMs in RAG through Grounded Attributions and Learning to Refuse
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Index Last Updated : 2024-11-16
__pycache__/utils.cpython-310.pyc ADDED
Binary file (8.96 kB).
app.py ADDED
@@ -0,0 +1,246 @@
+ # import os
+ # # Set CUDA device dynamically
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
+
+ import spaces
+ import torch
+ import transformers
+ import gradio as gr
+ from ragatouille import RAGPretrainedModel
+ from huggingface_hub import InferenceClient
+ import re
+ from datetime import datetime
+ import json
+ import os
+ import arxiv
+ from utils import get_md_text_abstract, search_cleaner, get_arxiv_live_search, make_demo, make_doc_prompt, load_llama_guard, moderate, LLM
+
+ global MODEL, CURRENT_MODEL
+ MODEL, CURRENT_MODEL = None, None
+
+ retrieve_results = 10
+ show_examples = True
+ llm_models_to_choose = ['Trust-Align-Qwen2.5', "meta-llama/Meta-Llama-3-8B-Instruct",'None']
+ llm_location_map={
+ "Trust-Align-Qwen2.5": os.getenv("MODEL_NAME"),
+ "meta-llama/Meta-Llama-3-8B-Instruct": "meta-llama/Meta-Llama-3-8B-Instruct", # "Qwen/Qwen2.5-7B-Instruct"
+ "None": None
+ }
+
+ generate_kwargs = dict(
+ temperature = 0.1,
+ max_new_tokens = 512,
+ top_p = 1.0,
+ do_sample = True,
+ )
+
+ # Load llama Guard
+ llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID = load_llama_guard("meta-llama/Llama-Guard-3-1B")
+
+ ## RAG MODEL
+ RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
+
+ try:
+ gr.Info("Setting up retriever, please wait...")
+ rag_initial_output = RAG.search("what is Mistral?", k = 1)
+ gr.Info("Retriever working successfully!")
+
+ except:
+ gr.Warning("Retriever not working!")
+
+ def choose_llm(choosed_llm):
+ global MODEL, CURRENT_MODEL
+ try:
+ gr.Info("Setting up LLM, please wait...")
+ MODEL = LLM(llm_location_map[choosed_llm], use_vllm=False)
+ CURRENT_MODEL = choosed_llm
+ gr.Info("LLM working successfully!")
+ except Exception as e:
+ raise RuntimeError("Failed to load the LLM MODEL.") from e
+
+ choose_llm(llm_models_to_choose[0])
+
+ # prompt used for generation
+ try:
+ with open("rejection_full.json") as f:
+ prompt_data = json.load(f)
+ except FileNotFoundError:
+ raise RuntimeError("Prompt data file 'rejection_full.json' not found.")
+ except json.JSONDecodeError:
+ raise RuntimeError("Failed to decode 'rejection_full.json'.")
+
+ ## Header
+ mark_text = '# 🔍 Search Results\n'
+ header_text = "# 🤖 Trust-Align: Measuring and Enhancing Trustworthiness of LLMs in RAG through Grounded Attributions and Learning to Refuse\n \n"
+
+ try:
+ with open("README.md", "r") as f:
+ mdfile = f.read()
+ date_pattern = r'Index Last Updated : \d{4}-\d{2}-\d{2}'
+ match = re.search(date_pattern, mdfile)
+ date = match.group().split(': ')[1]
+ formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
+ header_text += f'Index Last Updated: {formatted_date}\n'
+ index_info = f"Semantic Search - up to {formatted_date}"
+ except:
+ index_info = "Semantic Search"
+
+ database_choices = [index_info,'Arxiv Search - Latest - (EXPERIMENTAL)']
+
+ ## Arxiv API
+ arx_client = arxiv.Client()
+ is_arxiv_available = True
+ check_arxiv_result = get_arxiv_live_search("What is Mistral?", arx_client, retrieve_results)
+ if len(check_arxiv_result) == 0:
+ is_arxiv_available = False
+ print("Arxiv search not working, switching to default search ...")
+ database_choices = [index_info]
+
+
+
+ ## Show examples
+ if show_examples:
+ with open("sample_outputs.json", "r") as f:
+ sample_outputs = json.load(f)
+ output_placeholder = sample_outputs['output_placeholder']
+ md_text_initial = sample_outputs['search_placeholder']
+
+ else:
+ output_placeholder = None
+ md_text_initial = ''
+
+
+ def rag_cleaner(inp):
+ rank = inp['rank']
+ title = inp['document_metadata']['title']
+ content = inp['content']
+ date = inp['document_metadata']['_time']
+ return f"{rank}. <b> {title} </b> \n Date : {date} \n Abstract: {content}"
+
+ def get_references(question, retriever, k = retrieve_results):
+ rag_out = retriever.search(query=question, k=k)
+ return rag_out
+
+ def get_rag(message):
+ return get_references(message, RAG)
+
+ with gr.Blocks(theme = gr.themes.Soft()) as demo:
+ header = gr.Markdown(header_text)
+
+ with gr.Group():
+ msg = gr.Textbox(label = 'Search', placeholder = 'What is Mistral?')
+
+ with gr.Accordion("Advanced Settings", open=False):
+ with gr.Row(equal_height = True):
+ llm_model = gr.Dropdown(choices = llm_models_to_choose, value = 'Trust-Align-Qwen2.5', label = 'LLM MODEL')
+ llm_results = gr.Slider(minimum=1, maximum=retrieve_results, value=3, step=1, interactive=True, label="Top n results as context")
+ database_src = gr.Dropdown(choices = database_choices, value = index_info, label = 'Search Source')
+ stream_results = gr.Checkbox(value = True, label = "Stream output", visible = False)
+
+ output_text = gr.Textbox(show_label = True, container = True, label = 'LLM Answer', visible = True, placeholder = output_placeholder)
+ input = gr.Textbox(visible=False) # placeholder
+ gr_md = gr.Markdown(mark_text + md_text_initial)
+
+ def update_with_rag_md(message, llm_results_use = 5, database_choice = index_info, llm_model_picked = 'Trust-Align-Qwen2.5'):
+ chat_round = [
+ {"role": "user",
+ "content": [
+ {"type": "text",
+ "text": message
+ }
+ ]
+ }
+ ]
+ # llama guard check for it
+ prompt_safety = moderate(chat_round, llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID)['generated_text']
+
+ if prompt_safety == "safe":
+ docs = []
+ database_to_use = database_choice
+ if database_choice == index_info:
+ rag_out = get_rag(message)
+ else:
+ arxiv_search_success = True
+ try:
+ rag_out = get_arxiv_live_search(message, arx_client, retrieve_results)
+ if len(rag_out) == 0:
+ arxiv_search_success = False
+ except:
+ arxiv_search_success = False
+
+ if not arxiv_search_success:
+ gr.Warning("Arxiv Search not working, switching to semantic search ...")
+ rag_out = get_rag(message)
+ database_to_use = index_info
+
+ md_text_updated = mark_text
+ for i in range(retrieve_results):
+ rag_answer = rag_out[i]
+ if i < llm_results_use:
+ md_text_paper, doc = get_md_text_abstract(rag_answer, source = database_to_use, return_prompt_formatting = True)
+ docs.append(doc)
+ md_text_paper = md_text_paper.strip("###")
+ md_text_updated += f"### [{i+1}] {md_text_paper}"
+ # else:
+ # md_text_paper = get_md_text_abstract(rag_answer, source = database_to_use)
+ # md_text_updated += md_text_paper
+
+ infer_item = {
+ "question": message,
+ "docs": docs,
+ }
+ prompt = make_demo(
+ infer_item,
+ prompt=prompt_data["demo_prompt"],
+ ndoc=llm_results_use,
+ doc_prompt=prompt_data["doc_prompt"],
+ instruction=prompt_data["instruction"],
+ test=True
+ )
+ else:
+ md_text_updated = mark_text + "### Invalid search query!"
+ prompt = ""
+
+ return md_text_updated, prompt
+
+
+ @spaces.GPU(duration=60)
+ def ask_llm(prompt, llm_model_picked = 'Trust-Align-Qwen2.5', stream_outputs = False):
+ model_disabled_text = "LLM MODEL is disabled"
+ output = ""
+
+ if llm_model_picked == 'None':
+ if stream_outputs:
+ for out in model_disabled_text:
+ output += out
+ yield output
+ return output
+ else:
+ return model_disabled_text
+
+ global MODEL
+ if llm_model_picked != CURRENT_MODEL:
+ del MODEL
+ import gc
+ gc.collect()
+ torch.cuda.empty_cache()
+ choose_llm(llm_model_picked)
+
+ try:
+ stream = MODEL.generate(prompt, generate_kwargs["max_new_tokens"])
+ except:
+ gr.Warning("LLM Inference rate limit reached, try again later!")
+ return ""
+
+ if stream_outputs:
+ for response in stream:
+ output += response
+ yield output
+ return output
+ else:
+ return stream
+
+
+ msg.submit(update_with_rag_md, [msg, llm_results, database_src, llm_model], [gr_md, input]).success(ask_llm, [input, llm_model, stream_results], output_text)
+
+ demo.queue().launch()
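For orientation, a minimal sketch (not part of the commit) of how the two callbacks wired to `msg.submit` chain together, assuming the Space above is already running with its index and models loaded:

```python
# Hypothetical local driver reusing the functions defined in app.py above:
# update_with_rag_md retrieves papers, runs the Llama Guard check, and builds
# the grounded prompt; ask_llm then streams the cited answer.
search_md, prompt = update_with_rag_md(
    "What is Mistral?",
    llm_results_use=3,
    database_choice=index_info,
    llm_model_picked="Trust-Align-Qwen2.5",
)

if prompt:  # an empty prompt means Llama Guard flagged the query
    answer = ""
    for partial in ask_llm(prompt, "Trust-Align-Qwen2.5", stream_outputs=True):
        answer = partial  # each yield is the accumulated answer so far
    print(answer)
else:
    print(search_md)  # contains "Invalid search query!"
```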
rejection_full.json ADDED
@@ -0,0 +1,184 @@
+ {
+ "instruction": "Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents. If none of the provided documents contain the answer, only respond with \"I apologize, but I couldn't find an answer to your question in the search results.\". Then, add further explanation as to why an answer cannot be provided.",
+ "demo_sep": "\n\n\n",
+ "demo_prompt": "{INST}\n\nQuestion: {Q}\n\n{D}\nAnswer: {A}",
+ "doc_prompt": "Document [{ID}](Title: {T}): {P}\n",
+ "positive_demos": [
+ {
+ "question": "Which is the most rainy place on earth?",
+ "answer": "Several places on Earth claim to be the most rainy, such as Lloró, Colombia, which reported an average annual rainfall of 12,717 mm between 1952 and 1989, and López de Micay, Colombia, which reported an annual 12,892 mm between 1960 and 2012 [3]. However, the official record is held by Mawsynram, India with an average annual rainfall of 11,872 mm [3], although nearby town Sohra, India, also known as Cherrapunji, holds the record for most rain in a calendar month for July 1861 and most rain in a year from August 1860 to July 1861 [1].",
+ "docs": [
+ {
+ "title": "Cherrapunji",
+ "text": "Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. Cherrapunji still holds the all-time record for the most rainfall in a calendar month for July 1861 and most rain in a year from August 1860 to July 1861, however: it received in"
+ },
+ {
+ "title": "Cherrapunji",
+ "text": "Radio relay station known as Akashvani Cherrapunji. It broadcasts on FM frequencies. Cherrapunji Cherrapunji (; with the native name Sohra being more commonly used, and can also be spelled Cherrapunjee or Cherrapunji) is a subdivisional town in the East Khasi Hills district in the Indian state of Meghalaya. It is the traditional capital of aNongkhlaw \"hima\" (Khasi tribal chieftainship constituting a petty state), both known as Sohra or Churra. Cherrapunji has often been credited as being the wettest place on Earth, but for now nearby Mawsynram currently holds that distinction. Cherrapunji still holds the all-time record for the most rainfall"
+ },
+ {
+ "title": "Mawsynram",
+ "text": "Mawsynram Mawsynram () is a village in the East Khasi Hills district of Meghalaya state in north-eastern India, 65 kilometres from Shillong. Mawsynram receives one of the highest rainfalls in India. It is reportedly the wettest place on Earth, with an average annual rainfall of 11,872 mm, but that claim is disputed by Lloró, Colombia, which reported an average yearly rainfall of 12,717 mm between 1952 and 1989 and López de Micay, also in Colombia, which reported an annual 12,892 mm per year between 1960 and 2012. According to the \"Guinness Book of World Records\", Mawsynram received of rainfall in 1985. Mawsynram is located at 25° 18′"
+ },
+ {
+ "title": "Earth rainfall climatology",
+ "text": "Pacific Northwest, and the Sierra Nevada range are the wetter portions of the nation, with average rainfall exceeding per year. The drier areas are the Desert Southwest, Great Basin, valleys of northeast Arizona, eastern Utah, central Wyoming, eastern Oregon and Washington and the northeast of the Olympic Peninsula. The Big Bog on the island of Maui receives, on average, every year, making it the wettest location in the US, and all of Oceania. The annual average rainfall maxima across the continent lie across the northwest from northwest Brazil into northern Peru, Colombia, and Ecuador, then along the Atlantic coast of"
+ },
+ {
+ "title": "Going to Extremes",
+ "text": "in the world. Oymyakon in Siberia, where the average winter temperature is −47 °F (− 44 °C). Arica in Chile, where there had been fourteen consecutive years without rain. Fog is the only local source of water. Mawsynram in India, where average annual rainfall is 14 meters, falling within a four-month period in the monsoon season. The rainfall is approximately equal to that of its neighbor Cherrapunji. Dallol in Ethiopia, known as the 'Hell-hole of creation' where the temperature averages 94 °F (34 °C) over the year. In his second series, Middleton visited places without permanent towns, locations where \"survival\""
+ }
+ ]
+ },
+ {
+ "question": "When did the us break away from england?",
+ "answer": "The United States took the first step towards gaining independence from Great Britain when it declared independence from Great Britain on July 2, 1776 (although the event is now commemorated on July 4, 1776, the date when the Declaration of Independence was officially adopted by Congress) [2]. The Treaty of Paris was later signed on September 3, 1783, formally separating the United States from the British Empire [3].",
+ "docs": [
+ {
+ "title": "United States withdrawal from Saudi Arabia",
+ "text": "United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S. Navy, U.S. Army and U.S. Marine Corps remained in Saudi Arabia under the aegis of Joint Task Force Southwest Asia (JTF-SWA), as part of Operation Southern Watch (OSW). The United Kingdom and France also maintained a small contingent of Royal Air Force and French Air Force"
+ },
+ {
+ "title": "Decolonization of the Americas",
+ "text": "and France has fully \"integrated\" most of its former colonies as fully constituent \"departments\" of France. The United States of America declared independence from Great Britain on July 2, 1776 (although the event is now commemorated on July 4, the date when the Declaration of Independence was officially adopted by Congress), in so doing becoming the first independent, foreign-recognized nation in the Americas and the first European colonial entity to break from its mother country. Britain formally acknowledged American independence in 1783 after its defeat in the American Revolutionary War. Although initially occupying only the land east of the Mississippi"
+ },
+ {
+ "title": "American Revolution",
+ "text": "second British army at Yorktown in the fall of 1781, effectively ending the war. The Treaty of Paris was signed September 3, 1783, formally ending the conflict and confirming the new nation's complete separation from the British Empire. The United States took possession of nearly all the territory east of the Mississippi River and south of the Great Lakes, with the British retaining control of Canada and Spain taking Florida. Among the significant results of the revolution was the creation of the United States Constitution, establishing a relatively strong federal national government that included an executive, a national judiciary, and"
+ },
+ {
+ "title": "Decolonization",
+ "text": "accelerate decolonialization and bring an end to the colonial empires of its Western allies, most importantly during the 1956 Suez Crisis, but American military bases were established around the world and direct and indirect interventions continued in Korea, Indochina, Latin America (\"inter alia\", the 1965 occupation of the Dominican Republic), Africa, and the Middle East to oppose Communist invasions and insurgencies. Since the dissolution of the Soviet Union, the United States has been far less active in the Americas, but invaded Afghanistan and Iraq following the September 11 attacks in 2001, establishing army and air bases in Central Asia. Before"
+ },
+ {
+ "title": "Decolonization",
+ "text": "the responsibility of the United Kingdom (with a copy of the new constitution annexed), and finally, if approved, issuance of an Order of Council fixing the exact date of independence. After World War I, several former German and Ottoman territories in the Middle East, Africa, and the Pacific were governed by the UK as League of Nations mandates. Some were administered directly by the UK, and others by British dominions – Nauru and the Territory of New Guinea by Australia, South West Africa by the Union of South Africa, and Western Samoa by New Zealand. Egypt became independent in 1922,"
+ }
+ ]
+ }
+ ],
+ "reject_demos": [
+ {
+ "question": "Who set the record for longest field goal?",
+ "answer": "I apologize, but I couldn't find an answer to your question in the search results.",
+ "docs": [
+ {
+ "id": "5758609",
+ "title": "Sebastian Janikowski",
+ "text": "have broken the all-time NFL field goal record of 63 yards. However, it bounced off the right upright and came back out. On September 28, 2008, Janikowski unsuccessfully attempted a 76-yard field goal against the San Diego Chargers into the heavy wind right before halftime. This is presumed to be the longest attempt in NFL history; though the league keeps no such records on attempts, the longest known attempts previous to this were 74 yard attempts by Mark Moseley and Joe Danelo in 1979. On October 19, 2008, Janikowski broke his own Raiders team record, making a 57-yard field goal",
+ "score": 0.78466796875,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "12183799",
+ "title": "Dirk Borgognone",
+ "text": "Dirk Borgognone Dirk Ronald Borgognone (born January 9, 1968) is a former National Football League placekicker who currently holds the record for the longest field goal ever kicked in the history of high school football, 68 yards. Borgognone attended Reno High School, initially playing as a soccer player. He soon switched to football and was trained in a \"straight-on\" kicking style. On September 27, 1985, he kicked the longest field goal in high school football history, during a Reno High School game at Sparks High School. The kick measured 68 yards and was longer than any that had ever been",
+ "score": 0.7822265625,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "6048593",
+ "title": "Russell Erxleben",
+ "text": "against the University of Colorado. Despite early troubles in the game, including a missed field goal and a blocked extra point, Erxleben kicked a field goal late in the game, breaking the tie and securing the win for Texas. In 1977, in a game against Rice University, he set the record for the longest field goal in NCAA history with a 67-yard kick. UT head coach Fred Akers said of the kick, \"It was like a gunshot. We couldn't believe a ball was going that far. It had another eight yards on it.\" Erxleben kicked two other field goals over",
+ "score": 0.7705078125,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "9303115",
+ "title": "Field goal",
+ "text": "to Dempsey's 1970 kick, the longest field goal in NFL history was 56 yards, by Bert Rechichar of the Baltimore Colts A 55-yard field goal, achieved by a drop kick, was recorded by Paddy Driscoll in , and stood as the unofficial record until that point; some sources indicate a 54-yarder by Glenn Presnell in as the due to the inability to precisely verify Driscoll's kick. In a pre-season NFL game between the Denver Broncos and the Seattle Seahawks on August 29, 2002, Ola Kimrin kicked a 65-yard field goal. However, because pre-season games are not counted toward official records,",
+ "score": 0.76513671875,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "4853018",
+ "title": "Steve Christie",
+ "text": "Canton. He then went on to kick five field goals in the Bills 29-10 win over the Miami Dolphins in the AFC championship game, helping Buffalo get to their third consecutive Super Bowl. In 1993, Christie set a Bills record by kicking a 59-yard field goal in a regular season game. It was only four yards short of the all-time NFL record. In Super Bowl XXVIII, Christie set a Super Bowl record by kicking a 54-yard field goal. It is currently the longest field goal ever made in Super Bowl History. In the 2000 season, Christie was an instrumental part",
+ "score": 0.75341796875,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ }
+ ]
+ },
+ {
+ "question": "Who played galen in planet of the apes?",
+ "answer": "I apologize, but I couldn't find an answer to your question in the search results.",
+ "docs": [
+ {
+ "id": "12677620",
+ "title": "Planet of the Apes (2001 film)",
+ "text": "be the Lincoln Memorial, only to find that it is now a monument to General Thade. A swarm of police officers, firefighters, and news reporters descend on Leo, all of whom are apes. Small roles include David Warner (Senator Sandar), Lisa Marie (Nova), Erick Avari (Tival), Luke Eberl (Birn), Evan Parke (Gunnar), Glenn Shadix (Senator Nado), Freda Foh Shen (Bon), Chris Ellis (Lt. Gen. Karl Vasich) and Anne Ramsay (Lt. Col. Grace Alexander). There are also cameo appearances by Charlton Heston (uncredited) as Zaius, Thade's father, and Linda Harrison (the woman in the cart). Both participated in two original films",
+ "score": 0.7529296875,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "3943319",
+ "title": "Severn Darden",
+ "text": "Severn Darden Severn Teakle Darden Jr. (November 9, 1929 \u2013 May 27, 1995) was an American comedian and actor, and an original member of The Second City Chicago-based comedy troupe as well as its predecessor, the Compass Players. He is perhaps best known from his film appearances for playing the human leader Kolp in the fourth and fifth \"Planet of the Apes\" films. Born in New Orleans, Louisiana, he attended the University of Chicago. Darden\u2019s offbeat and intellectual sense of humor, appropriate for someone who attended the University of Chicago and in fact a major element in the style of",
+ "score": 0.74267578125,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "13813715",
+ "title": "Planet of the Apes",
+ "text": "film stars Mark Wahlberg as astronaut Leo Davidson, who accidentally travels through a wormhole to a distant planet where talking apes enslave humans. He leads a human revolt and upends ape civilization by discovering that the apes evolved from the normal earth primates who had accompanied his mission, and arrived years before. Helena Bonham Carter played chimpanzee Ari, while Tim Roth played the human-hating chimpanzee General Thade. The film received mixed reviews; most critics believed it failed to compare to the original. Much of the negative commentary focused on the confusing plot and twist ending, though many reviewers praised the",
+ "score": 0.74169921875,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "3386349",
+ "title": "Maurice Evans (actor)",
+ "text": "Maurice Evans (actor) Maurice Herbert Evans (3 June 1901 \u2013 12 March 1989) was a British actor, noted for his interpretations of Shakespearean characters. His best-known screen roles are Dr. Zaius in the 1968 film \"Planet of the Apes\" and as Samantha Stephens's father, Maurice, on \"Bewitched\". Evans was born at 28 Icen Way (where there is now a memorial plaque, unveiled in 2013 by Tegen Evans, his great-great niece) in Dorchester, Dorset. He was the son of Laura (Turner) and Alfred Herbert Evans, a Welsh dispensing chemist and keen amateur actor who made adaptations of novels by Thomas Hardy",
+ "score": 0.734375,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ },
+ {
+ "id": "823444",
+ "title": "Ricardo Montalba\u0301n",
+ "text": "was played by Andy Garc\u00eda. Ricardo Montalb\u00e1n Ricardo Gonzalo Pedro Montalb\u00e1n y Merino, (; ; November 25, 1920 \u2013 January 14, 2009) was a Mexican actor. His career spanned seven decades, during which he became known for many different performances in a variety of genres, from crime and drama to musicals and comedy. Among his notable roles was Armando in the \"Planet of the Apes\" film series from the early 1970s, where he starred in \"Escape from the Planet of the Apes\" (1971) and \"Conquest of the Planet of the Apes\" (1972). Ricardo Montalb\u00e1n played Mr. Roarke on the television",
+ "score": 0.7314453125,
+ "answers_found": [
+ 0,
+ 0
+ ],
+ "rec_score": 0.0
+ }
+ ]
+ }
+ ]
+ }
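The fields above are consumed by `make_demo` in `utils.py`: `{INST}`, `{Q}`, `{D}`, and `{A}` in `demo_prompt` are filled from `instruction`, the user question, the documents rendered through `doc_prompt`, and the answer. A small sketch (not part of the commit; the question and document are made up) of what that assembly looks like at inference time:

```python
import json
from utils import make_demo

with open("rejection_full.json") as f:
    prompt_data = json.load(f)

# Hypothetical question/document pair, shaped like the `infer_item` built in app.py.
item = {
    "question": "What is Mistral 7B?",
    "docs": [{"title": "Mistral 7B", "text": "We introduce Mistral 7B v0.1, a 7-billion-parameter language model ..."}],
}

prompt = make_demo(
    item,
    prompt=prompt_data["demo_prompt"],      # "{INST}\n\nQuestion: {Q}\n\n{D}\nAnswer: {A}"
    ndoc=1,
    doc_prompt=prompt_data["doc_prompt"],   # "Document [{ID}](Title: {T}): {P}\n"
    instruction=prompt_data["instruction"],
    test=True,                              # inference mode: the answer slot is left empty
)
print(prompt)
```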
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ gradio==5.5.0
+ spaces==0.30.3
+ PyMuPDF==1.24.14
+ llama-index==0.12.1
+ llama-index-vector-stores-faiss==0.3.0
+ chromadb==0.5.20
+ llama-index-vector-stores-chroma==0.4.0
+ llama-index-embeddings-huggingface==0.4.0
+ vllm==0.6.2
+ sentence-transformers==2.7.0
+ arxiv
+ ragatouille
+ hf_transfer
+ colorlog
+ accelerate==1.1.1
sample_outputs.json ADDED
@@ -0,0 +1 @@
+ {"search_placeholder": "### 10 Oct 2023 | [Mistral 7B](https://arxiv.org/abs/2310.06825) | [\u2b07\ufe0f](https://arxiv.org/pdf/2310.06825)\n*Albert Q. Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lucile Saulnier, L'elio Renard Lavaud, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timoth'ee Lacroix, William El Sayed* \n\nWe introduce Mistral 7B v0.1, a 7-billion-parameter language model engineered\nfor superior performance and efficiency. Mistral 7B outperforms Llama 2 13B\nacross all evaluated benchmarks, and Llama 1 34B in reasoning, mathematics, and\ncode generation. Our model leverages grouped-query attention (GQA) for faster\ninference, coupled with sliding window attention (SWA) to effectively handle\nsequences of arbitrary length with a reduced inference cost. We also provide a\nmodel fine-tuned to follow instructions, Mistral 7B -- Instruct, that surpasses\nthe Llama 2 13B -- Chat model both on human and automated benchmarks. Our\nmodels are released under the Apache 2.0 license.\n", "output_placeholder": "Mistral is a 7-billion-parameter language model engineered for superior performance and efficiency. It was introduced in the paper \"Mistral 7B: A Superior Large Language Model\" [1]. Mistral outperforms other language models like Llama 2 13B and Llama 1 34B in various benchmarks, including reasoning, mathematics, and code generation. The model uses grouped-query attention (GQA) and sliding window attention (SWA) for faster inference and handling sequences of arbitrary length with reduced inference cost. Additionally, a fine-tuned version of Mistral, Mistral 7B -- Instruct, was released, which surpasses Llama 2 13B -- Chat model on human and automated benchmarks [1]. \n[1] Mistral 7B: A Superior Large Language Model. (2023). Retrieved from https://arxiv.org/abs/2303.14311."}
utils.py ADDED
@@ -0,0 +1,321 @@
+ import datetime
+ import string
+
+ import nltk
+
+ nltk.download('stopwords')
+ from nltk.corpus import stopwords
+
+ stop_words = stopwords.words('english')
+ import time
+
+ import arxiv
+ import colorlog
+ import torch
+
+ fmt_string = '%(log_color)s %(asctime)s - %(levelname)s - %(message)s'
+ log_colors = {
+ 'DEBUG': 'white',
+ 'INFO': 'green',
+ 'WARNING': 'yellow',
+ 'ERROR': 'red',
+ 'CRITICAL': 'purple'
+ }
+ colorlog.basicConfig(log_colors=log_colors, format=fmt_string, level=colorlog.INFO)
+ logger = colorlog.getLogger(__name__)
+ logger.setLevel(colorlog.INFO)
+
+
+
+ def get_md_text_abstract(rag_answer, source = ['Arxiv Search', 'Semantic Search'][1], return_prompt_formatting = False):
+ if 'Semantic Search' in source:
+ title = rag_answer['document_metadata']['title'].replace('\n','')
+ #score = round(rag_answer['score'], 2)
+ date = rag_answer['document_metadata']['_time']
+ paper_abs = rag_answer['content']
+ authors = rag_answer['document_metadata']['authors'].replace('\n','')
+ doc_id = rag_answer['document_id']
+ paper_link = f'''https://arxiv.org/abs/{doc_id}'''
+ download_link = f'''https://arxiv.org/pdf/{doc_id}'''
+
+ elif 'Arxiv' in source:
+ title = rag_answer.title
+ date = rag_answer.updated.strftime('%d %b %Y')
+ paper_abs = rag_answer.summary.replace('\n',' ') + '\n'
+ authors = ', '.join([author.name for author in rag_answer.authors])
+ paper_link = rag_answer.links[0].href
+ download_link = rag_answer.links[1].href
+
+ else:
+ raise Exception
+
+ paper_title = f'''### {date} | [{title}]({paper_link}) | [⬇️]({download_link})\n'''
+ authors_formatted = f'*{authors}*' + ' \n\n'
+
+ md_text_formatted = paper_title + authors_formatted + paper_abs + '\n---------------\n'+ '\n'
+ if return_prompt_formatting:
+ doc = {
+ 'title': title,
+ 'text': paper_abs
+ }
+ return md_text_formatted, doc
+
+ return md_text_formatted
+
+ def remove_punctuation(text):
+ punct_str = string.punctuation
+ punct_str = punct_str.replace("'", "")
+ return text.translate(str.maketrans("", "", punct_str))
+
+ def remove_stopwords(text):
+ text = ' '.join(word for word in text.split(' ') if word not in stop_words)
+ return text
+
+ def search_cleaner(text):
+ new_text = text.lower()
+ new_text = remove_stopwords(new_text)
+ new_text = remove_punctuation(new_text)
+ return new_text
+
+
+ q = '(cat:cs.CV OR cat:cs.LG OR cat:cs.CL OR cat:cs.AI OR cat:cs.NE OR cat:cs.RO)'
+
+
+ def get_arxiv_live_search(query, client, max_results = 10):
+ clean_text = search_cleaner(query)
+ search = arxiv.Search(
+ query = clean_text + " AND "+q,
+ max_results = max_results,
+ sort_by = arxiv.SortCriterion.Relevance
+ )
+ results = client.results(search)
+ all_results = list(results)
+ return all_results
+
+
+ def make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=None):
+ # For doc prompt:
+ # - {ID}: doc id (starting from 1)
+ # - {T}: title
+ # - {P}: text
+ # use_shorter: None, "summary", or "extraction"
+
+ text = doc['text']
+ if use_shorter is not None:
+ text = doc[use_shorter]
+ return doc_prompt.replace("{T}", doc["title"]).replace("{P}", text).replace("{ID}", str(doc_id+1))
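A quick illustration (editor's sketch, not part of the file) of the placeholder substitution described in the comment above, using the `doc_prompt` template from `rejection_full.json` and a made-up document:

```python
from utils import make_doc_prompt

# Hypothetical document; doc_id is 0-based, so the rendered ID becomes 1.
doc = {"title": "Mistral 7B", "text": "We introduce Mistral 7B v0.1 ..."}
print(make_doc_prompt(doc, 0, "Document [{ID}](Title: {T}): {P}\n"))
# -> Document [1](Title: Mistral 7B): We introduce Mistral 7B v0.1 ...
```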
+
+
+ def get_shorter_text(item, docs, ndoc, key):
+ doc_list = []
+ for item_id, item in enumerate(docs):
+ if key not in item:
+ if len(doc_list) == 0:
+ # If there aren't any document, at least provide one (using full text)
+ item[key] = item['text']
+ doc_list.append(item)
+ logger.warn(f"No {key} found in document. It could be this data do not contain {key} or previous documents are not relevant. This is document {item_id}. This question will only have {len(doc_list)} documents.")
+ break
+ if "irrelevant" in item[key] or "Irrelevant" in item[key]:
+ continue
+ doc_list.append(item)
+ if len(doc_list) >= ndoc:
+ break
+ return doc_list
+
+
+ def make_demo(item, prompt, ndoc=None, doc_prompt=None, instruction=None, use_shorter=None, test=False):
+ # For demo prompt
+ # - {INST}: the instruction
+ # - {D}: the documents
+ # - {Q}: the question
+ # - {A}: the answers
+ # ndoc: number of documents to put in context
+ # use_shorter: None, "summary", or "extraction"
+
+ prompt = prompt.replace("{INST}", instruction).replace("{Q}", item['question'])
+ if "{D}" in prompt:
+ if ndoc == 0:
+ prompt = prompt.replace("{D}\n", "") # if there is no doc we also delete the empty line
+ else:
+ doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
+ text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
+ prompt = prompt.replace("{D}", text)
+
+ if not test:
+ answer = "\n" + "\n".join(item["answer"]) if isinstance(item["answer"], list) else item["answer"]
+ prompt = prompt.replace("{A}", "").rstrip() + answer
+ else:
+ prompt = prompt.replace("{A}", "").rstrip() # remove any space or \n
+
+ return prompt
+
+
+ def load_llama_guard(model_id = "meta-llama/Llama-Guard-3-1B"):
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ dtype = torch.bfloat16
+
+ logger.info("loading llama_guard")
+ llama_guard_tokenizer = AutoTokenizer.from_pretrained(model_id)
+ llama_guard = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map="cuda")
+
+ # Get the id of the "unsafe" token, this will later be used to extract its probability
+ UNSAFE_TOKEN_ID = llama_guard_tokenizer.convert_tokens_to_ids("unsafe")
+
+ return llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID
+
+
+ def moderate(chat, model, tokenizer, UNSAFE_TOKEN_ID):
+
+ prompt = tokenizer.apply_chat_template(chat, return_tensors="pt", tokenize=False)
+ # Skip the generation of whitespace.
+ # Now the next predicted token will be either "safe" or "unsafe"
+ prompt += "\n\n"
+
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+ outputs = model.generate(
+ **inputs,
+ max_new_tokens=50,
+ return_dict_in_generate=True,
+ pad_token_id=tokenizer.eos_token_id,
+ output_logits=True, # get logits
+ )
+ ######
+ # Get generated text
+ ######
+
+ # Number of tokens that correspond to the input prompt
+ input_length = inputs.input_ids.shape[1]
+ # Ignore the tokens from the input to get the tokens generated by the model
+ generated_token_ids = outputs.sequences[:, input_length:].cpu()
+ generated_text = tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
+
+ ######
+ # Get Probability of "unsafe" token
+ ######
+
+ # First generated token is either "safe" or "unsafe".
+ # use the logits to calculate the probabilities.
+ first_token_logits = outputs.logits[0]
+ first_token_probs = torch.softmax(first_token_logits, dim=-1)
+
+ # From the probabilities of all tokens, extract the one for the "unsafe" token.
+ unsafe_probability = first_token_probs[0, UNSAFE_TOKEN_ID]
+ unsafe_probability = unsafe_probability.item()
+
+ ######
+ # Result
+ ######
+ return {
+ "unsafe_score": unsafe_probability,
+ "generated_text": generated_text
+ }
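A short usage sketch (not part of the file) for the two functions above: load the guard once, then screen each incoming query. Per the code, `generated_text` begins with "safe" or "unsafe", and `unsafe_score` is the probability mass on the "unsafe" token at the first generated position.

```python
from utils import load_llama_guard, moderate

# Assumes a CUDA device is available (load_llama_guard uses device_map="cuda").
guard, guard_tokenizer, unsafe_token_id = load_llama_guard("meta-llama/Llama-Guard-3-1B")

# Same message format app.py builds before calling moderate().
chat = [{"role": "user", "content": [{"type": "text", "text": "What is Mistral?"}]}]
verdict = moderate(chat, guard, guard_tokenizer, unsafe_token_id)

if verdict["generated_text"] == "safe":
    print("query allowed; P(unsafe) =", verdict["unsafe_score"])
```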
+
+
+
+ def get_max_memory():
+ """Get the maximum memory available for the current GPU for loading models."""
+ free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
+ max_memory = f'{free_in_GB-1}GB'
+ n_gpus = torch.cuda.device_count()
+ max_memory = {i: max_memory for i in range(n_gpus)}
+ return max_memory
+
+
+ def load_model(model_name_or_path, dtype=torch.bfloat16, int8=False):
+ # Load a huggingface model and tokenizer
+ # dtype: torch.float16 or torch.bfloat16
+ # int8: whether to use int8 quantization
+ # reserve_memory: how much memory to reserve for the model on each gpu (in GB)
+
+ # Load the FP16 model
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ logger.info(f"Loading {model_name_or_path} in {dtype}...")
+ if int8:
+ logger.warn("Use LLM.int8")
+ start_time = time.time()
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name_or_path,
+ device_map='auto',
+ torch_dtype=dtype,
+ max_memory=get_max_memory(),
+ load_in_8bit=int8,
+ )
+ logger.info("Finish loading in %.2f sec." % (time.time() - start_time))
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
+
+ tokenizer.padding_side = "left"
+
+ return model, tokenizer
+
+
+ def load_vllm(model_name_or_path, dtype=torch.bfloat16):
+ from vllm import LLM, SamplingParams
+ logger.info(f"Loading {model_name_or_path} in {dtype}...")
+ start_time = time.time()
+ model = LLM(
+ model_name_or_path,
+ dtype=dtype,
+ gpu_memory_utilization=0.9,
+ max_seq_len_to_capture=2048,
+ max_model_len=8192,
+ )
+ sampling_params = SamplingParams(temperature=0.1, top_p=1.00, max_tokens=300)
+ logger.info("Finish loading in %.2f sec." % (time.time() - start_time))
+
+ # Load the tokenizer
+ tokenizer = model.get_tokenizer()
+
+ tokenizer.padding_side = "left"
+
+ return model, tokenizer, sampling_params
+
+
+
+ class LLM:
+
+ def __init__(self, model_name_or_path, use_vllm=True):
+ self.use_vllm = use_vllm
+ if use_vllm:
+ self.chat_llm, self.tokenizer, self.sampling_params = load_vllm(model_name_or_path)
+ else:
+ self.chat_llm, self.tokenizer = load_model(model_name_or_path)
+
+ self.prompt_exceed_max_length = 0
+ self.fewer_than_50 = 0
+
+ def generate(self, prompt, max_tokens=300, stop=None):
+ if max_tokens <= 0:
+ self.prompt_exceed_max_length += 1
+ logger.warning("Prompt exceeds max length and return an empty string as answer. If this happens too many times, it is suggested to make the prompt shorter")
+ return ""
+ if max_tokens < 50:
+ self.fewer_than_50 += 1
+ logger.warning("The model can at most generate < 50 tokens. If this happens too many times, it is suggested to make the prompt shorter")
+
+ if self.use_vllm:
+ inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
+ self.sampling_params.n = 1 # Number of output sequences to return for the given prompt
+ self.sampling_params.stop_token_ids = [self.chat_llm.llm_engine.get_model_config().hf_config.eos_token_id]
+ self.sampling_params.max_tokens = max_tokens
+ output = self.chat_llm.generate(
+ inputs,
+ self.sampling_params,
+ use_tqdm=True,
+ )
+ generation = output[0].outputs[0].text.strip()
+
+ else:
+ inputs = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, return_dict=True, return_tensors="pt").to(self.chat_llm.device)
+ outputs = self.chat_llm.generate(
+ **inputs,
+ do_sample=True, temperature=0.1, top_p=1.0,
+ max_new_tokens=max_tokens,
+ num_return_sequences=1,
+ eos_token_id=[self.chat_llm.config.eos_token_id]
+ )
+ generation = self.tokenizer.decode(outputs[0][inputs['input_ids'].size(1):], skip_special_tokens=True).strip()
+
+ return generation
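Finally, a minimal sketch (not part of the commit) of how `app.py` drives the `LLM` wrapper above: the model id is whatever the `MODEL_NAME` environment variable points to, and `use_vllm=False` selects the plain transformers path used in the Space.

```python
import os
from utils import LLM

# Hypothetical standalone use of the wrapper; app.py does the same inside choose_llm().
model = LLM(os.getenv("MODEL_NAME"), use_vllm=False)  # use_vllm=True would go through load_vllm instead
answer = model.generate("Question: What is Mistral?\n\nAnswer:", max_tokens=512)
print(answer)
```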