arsiba committed
Commit f6cf353 · 1 Parent(s): e4a8e02

feat: use nvidia llama with optional reasoning

Files changed (1)
  1. app.py +25 -27
app.py CHANGED
@@ -16,7 +16,7 @@ with open("vector_db/metadata.pkl", "rb") as f:
 
 ST = SentenceTransformer("BAAI/bge-large-en-v1.5")
 
-model_id = "microsoft/phi-2"
+model_id = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"
 bnb = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
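A back-of-the-envelope check (ours, not from the commit) on why the jump from microsoft/phi-2 (~2.7B parameters) to an 8B model still fits the same hardware: with the 4-bit load configured above, the weights take roughly 8e9 params × 0.5 byte ≈ 4 GB, versus about 16 GB in bf16.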
@@ -24,6 +24,7 @@ bnb = BitsAndBytesConfig(
     bnb_4bit_compute_dtype=torch.bfloat16
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+tokenizer.pad_token_id = tokenizer.eos_token_id
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     quantization_config=bnb,
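Llama-family tokenizers generally ship without a dedicated pad token, which is presumably what the added pad_token_id line works around. A slightly more defensive variant (a sketch, not part of the commit) fills it in only when missing:

# Sketch: only map pad to EOS when the tokenizer defines no pad token,
# leaving tokenizers that already ship one untouched.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id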
@@ -33,36 +34,35 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 
 SYS = (
-    "You are a legal AI assistant specialized in GDPR/EDPB. "
-    "If you cannot find an answer in the context, reply 'I do not know.' Answer this Question:"
+    "You are a legal AI assistant specialized in GDPR/EDPB. "
+    "If you cannot find an answer in the context, reply 'I do not know.' "
+    "Answer this Question:"
 )
 
 def retrieve(q, k=3):
     emb = ST.encode(q)
     D, I = index.search(np.array([emb], dtype="float32"), k)
-
-    docs = []
-    file_sources = []
-
+    docs, file_sources = [], []
     for i in I[0]:
         chunk = chunks[i]
-        metadata = metadata_dict[i]
-        docs.append({
-            "title": chunk,
-            "pages": chunk
-        })
-        file_sources.append(metadata["source"])
-
+        meta = metadata_dict[i]
+        docs.append({"title": chunk, "pages": chunk})
+        file_sources.append(meta["source"])
     return docs, file_sources
 
-def make_prompt(q, docs):
-    context = "\n\n".join(f"Title: {doc['title']}\nPages: {doc['pages']}" for doc in docs)
-    return f"Instruct:{SYS} {q} based on the following documents:{context}\nOutput:"
+
+def make_prompt(q, docs, reasoning_mode):
+    context = "\n\n".join(f"Title: {d['title']}\nPages: {d['pages']}" for d in docs)
+    prompt = f"detailed thinking {reasoning_mode}\n"
+    if reasoning_mode == "off":
+        prompt += "eager_mode on\n"
+    prompt += f"Instruct: {SYS} {q} based on the following documents:\n{context}\nOutput:"
+    return prompt
 
 @spaces.GPU()
-def qa_fn(question, top_k, temperature, max_tokens):
+def qa_fn(question, reasoning_mode, top_k, temperature, max_tokens):
     docs, file_sources = retrieve(question, top_k)
-    prompt = make_prompt(question, docs)[:8000]
+    prompt = make_prompt(question, docs, reasoning_mode)[:8000]
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
     inputs = {k: v.to(model.device) for k, v in inputs.items()}
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
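The "detailed thinking on/off" prefix that make_prompt now emits is the reasoning switch documented on the nvidia/Llama-3.1-Nemotron-Nano-8B-v1 model card, where it is expressed as a system message rendered through the chat template; the commit prepends it to a raw Instruct/Output prompt instead, and the extra "eager_mode on" line is the author's own addition we cannot confirm from the card. A sketch of the chat-template variant, reusing the names from the code above:

# Sketch: Nemotron's documented reasoning toggle is a system message;
# apply_chat_template renders messages into the model's expected format.
messages = [
    {"role": "system", "content": f"detailed thinking {reasoning_mode}"},
    {"role": "user", "content": f"{SYS} {q} based on the following documents:\n{context}"},
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)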
@@ -78,11 +78,8 @@ def qa_fn(question, top_k, temperature, max_tokens):
     output = ""
     for tok in streamer:
         output += tok
-
-    think_tag_index = output.find("</think>")  # change to "Answer:" after testing
-    if think_tag_index != -1:
-        output = output[think_tag_index + len("</think>"):].strip()
-
+    if "</think>" in output:
+        output = output.split("</think>", 1)[1].strip()
     return output, file_sources
 
 outputs_answer = gr.Textbox(label="Answer")
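For context on the loop above: TextIteratorStreamer only yields tokens while model.generate runs concurrently, so the unchanged middle of qa_fn (not shown in this diff) presumably starts generation on a worker thread, along these lines (generation kwargs assumed):

from threading import Thread

# Sketch: run generate in a background thread so the `for tok in streamer`
# loop can consume tokens as they are produced.
gen_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=max_tokens,
    do_sample=True,
    temperature=temperature,
)
Thread(target=model.generate, kwargs=gen_kwargs).start()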
@@ -92,9 +89,10 @@ demo = gr.Interface(
     fn=qa_fn,
     inputs=[
         gr.Textbox(lines=2, label="Your Question"),
-        gr.Slider(1, 7, value=4, step=1, label="Top-K Documents (How many chunks to include for context)"),
-        gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature (Higher = more creative, lower = more focused)"),
-        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length (Maximum tokens to generate)")
+        gr.Radio(["on", "off"], value="off", label="Reasoning Mode"),
+        gr.Slider(1, 7, value=4, step=1, label="Top-K Documents"),
+        gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Temperature"),
+        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length")
     ],
     outputs=[outputs_answer, outputs_sources],
     title="GDPR Legal Assistant",
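One wiring detail worth noting: gr.Interface passes component values to fn positionally, so the new Radio has to sit second in inputs to line up with the reasoning_mode parameter added to qa_fn:

# Positional mapping (illustration only):
#   Textbox -> question, Radio -> reasoning_mode,
#   Slider -> top_k, Slider -> temperature, Slider -> max_tokens
def qa_fn(question, reasoning_mode, top_k, temperature, max_tokens):
    ...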
 