feat: use nvidia llama with optional reasoning
app.py CHANGED
@@ -16,7 +16,7 @@ with open("vector_db/metadata.pkl", "rb") as f:
 
 ST = SentenceTransformer("BAAI/bge-large-en-v1.5")
 
-model_id = "
+model_id = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"
 bnb = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
@@ -24,6 +24,7 @@ bnb = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+tokenizer.pad_token_id = tokenizer.eos_token_id
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     quantization_config=bnb,
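Taken together, the first two hunks give roughly the load path below. This is a sketch, not the file itself: `bnb_4bit_quant_type` and `device_map` are not visible in this diff and are assumptions. The new `pad_token_id` line matters because Llama-family tokenizers ship without a pad token, so setting it up front keeps padding and `generate()` well-defined.

```python
# Sketch of the load path implied by hunks 1-2; quant_type and device_map are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"

bnb = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights in 4-bit
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",              # assumption: not shown in this diff
    bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Llama tokenizers define no pad token; reusing EOS gives padding and generate()
# a valid pad_token_id.
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb,
    device_map="auto",          # assumption: the remaining kwargs sit outside the hunk
    trust_remote_code=True,
)
```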
@@ -33,36 +34,35 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 SYS = (
-    "You are a legal AI assistant specialized in GDPR/EDPB. "
-    "If you cannot find an answer in the context, reply 'I do not know.'
+    "You are a legal AI assistant specialized in GDPR/EDPB. "
+    "If you cannot find an answer in the context, reply 'I do not know.' "
+    "Answer this Question:"
 )
 
 def retrieve(q, k=3):
     emb = ST.encode(q)
     D, I = index.search(np.array([emb], dtype="float32"), k)
-
-    docs = []
-    file_sources = []
-
+    docs, file_sources = [], []
     for i in I[0]:
         chunk = chunks[i]
-
-        docs.append({
-
-            "pages": chunk
-        })
-        file_sources.append(metadata["source"])
-
+        meta = metadata_dict[i]
+        docs.append({"title": chunk, "pages": chunk})
+        file_sources.append(meta["source"])
     return docs, file_sources
 
-
-
-
+
+def make_prompt(q, docs, reasoning_mode):
+    context = "\n\n".join(f"Title: {d['title']}\nPages: {d['pages']}" for d in docs)
+    prompt = f"detailed thinking {reasoning_mode}\n"
+    if reasoning_mode == "off":
+        prompt += "eager_mode on\n"
+    prompt += f"Instruct: {SYS} {q} based on the following documents:\n{context}\nOutput:"
+    return prompt
 
 @spaces.GPU()
-def qa_fn(question, top_k, temperature, max_tokens):
+def qa_fn(question, reasoning_mode, top_k, temperature, max_tokens):
     docs, file_sources = retrieve(question, top_k)
-    prompt = make_prompt(question, docs)[:8000]
+    prompt = make_prompt(question, docs, reasoning_mode)[:8000]
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
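To make the new prompt format concrete, here is roughly what `make_prompt` produces for a toy question with reasoning disabled. The document dict is made up for illustration; `detailed thinking on/off` is the Nemotron reasoning toggle this commit wires to the new UI switch.

```python
# Hypothetical inputs, only to show the shape of the prompt string.
docs = [{"title": "EDPB Guidelines 07/2020", "pages": "Concepts of controller and processor ..."}]
print(make_prompt("Who qualifies as a data controller?", docs, "off"))
# detailed thinking off
# eager_mode on
# Instruct: You are a legal AI assistant specialized in GDPR/EDPB. If you cannot find an
#   answer in the context, reply 'I do not know.' Answer this Question: Who qualifies as a
#   data controller? based on the following documents:
# Title: EDPB Guidelines 07/2020
# Pages: Concepts of controller and processor ...
# Output:
```

(The `Instruct:` portion is a single line in practice; it is wrapped above for readability.)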
@@ -78,11 +78,8 @@ def qa_fn(question, top_k, temperature, max_tokens):
     output = ""
     for tok in streamer:
         output += tok
-
-
-    if think_tag_index != -1:
-        output = output[think_tag_index + len("</think>"):].strip()
-
+    if "</think>" in output:
+        output = output.split("</think>", 1)[1].strip()
     return output, file_sources
 
 outputs_answer = gr.Textbox(label="Answer")
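The hunk above only shows the tail of `qa_fn`; the call that feeds the `TextIteratorStreamer` lives in the elided lines 69-77. It presumably follows the standard background-thread pattern sketched below (the exact `generate()` kwargs are not visible in this diff).

```python
# Assumed shape of the elided generation call; do_sample=True is a guess.
from threading import Thread

gen_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=max_tokens,   # "Max Answer Length" slider
    temperature=temperature,     # "Temperature" slider
    do_sample=True,
)
# generate() blocks until decoding finishes, so it runs on a background thread
# while the loop shown in the hunk drains the streamer token by token.
Thread(target=model.generate, kwargs=gen_kwargs).start()
```

With `detailed thinking on`, Nemotron emits its chain of thought inside `<think>...</think>`, which is why the new code keeps only the text after the closing tag.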
@@ -92,9 +89,10 @@ demo = gr.Interface(
     fn=qa_fn,
     inputs=[
         gr.Textbox(lines=2, label="Your Question"),
-        gr.
-        gr.Slider(
-        gr.Slider(
+        gr.Radio(["on", "off"], value="off", label="Reasoning Mode"),
+        gr.Slider(1, 7, value=4, step=1, label="Top-K Documents"),
+        gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Temperature"),
+        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length")
     ],
     outputs=[outputs_answer, outputs_sources],
     title="GDPR Legal Assistant",
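Gradio passes the `inputs` list to `qa_fn` positionally, so the new Radio lands in the second slot and lines up with the added `reasoning_mode` parameter. A quick way to sanity-check the wiring without the UI (the question text is just an example):

```python
# Same positional order as the inputs list:
# question, reasoning_mode, top_k, temperature, max_tokens
answer, sources = qa_fn(
    "Does a small company need a Data Protection Officer?",  # Textbox
    "off",  # Radio "Reasoning Mode"
    4,      # Slider "Top-K Documents" (default)
    0.6,    # Slider "Temperature" (default)
    512,    # Slider "Max Answer Length" (default)
)
print(sources)
print(answer)
```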