arsiba committed
Commit 7568c7e · Parent: b3ef434

fix: revert to last functional version

Files changed (1):
  1. app.py +55 -59
app.py CHANGED
@@ -3,9 +3,10 @@ import pickle
 import numpy as np
 import faiss
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
 from sentence_transformers import SentenceTransformer
 import gradio as gr
+from threading import Thread

 index = faiss.read_index("vector_db/index.faiss")
 with open("vector_db/chunks.pkl", "rb") as f:
@@ -15,7 +16,6 @@ with open("vector_db/metadata.pkl", "rb") as f:

 ST = SentenceTransformer("BAAI/bge-large-en-v1.5")

-chunk_embeddings = ST.encode(chunks, convert_to_numpy=True)
 model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
 bnb = BitsAndBytesConfig(
     load_in_4bit=True,
@@ -30,80 +30,76 @@ model = AutoModelForCausalLM.from_pretrained(
     device_map={"": 0},
     torch_dtype=torch.bfloat16
 )
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    device_map={"": 0},
-    quantization_config=bnb
-)

 SYS = (
-    "### System\n"
-    "You are a legal AI assistant specialized in GDPR/EDPB. "
-    "Answer based only on provided context. If uncertain, say 'I do not know.'\n"
+    "You are a legal AI assistant. Answer the user's question "
+    "based only on the given legal context from GDPR and EDPB documents. "
+    "Be accurate, use clear language, and do not make assumptions. "
+    "If unsure, say: 'I do not know.'"
 )

-def retrieve(query: str, k: int = 3):
-    emb = ST.encode([query], convert_to_numpy=True)[0]
+def retrieve(q, k=3):
+    emb = ST.encode(q)
     D, I = index.search(np.array([emb], dtype="float32"), k)
-    idxs = [int(i) for i in I[0]]
-    selected_embs = chunk_embeddings[idxs]
-    sims = np.dot(selected_embs, emb) / (np.linalg.norm(selected_embs, axis=1) * np.linalg.norm(emb))
-    order = np.argsort(-sims)
-    docs, sources = [], []
-    for pos in order:
-        i = idxs[pos]
-        docs.append({"title": f"Chunk {i}", "pages": chunks[i]})
-        sources.append(metadata_dict[i]["source"])
-    return docs, sources
+
+    docs = []
+    file_sources = []
+
+    for i in I[0]:
+        chunk = chunks[i]
+        metadata = metadata_dict[i]
+        docs.append({
+            "title": chunk,
+            "pages": chunk
+        })
+        file_sources.append(metadata["source"])
+
+    return docs, file_sources

-def make_prompt(question: str, docs: list) -> str:
-    context = "\n---\n".join(f"Title: {d['title']}\n{d['pages']}" for d in docs)
-    return (
-        f"{SYS}"
-        f"### Context\n{context}\n"
-        "### Question\n" + question + "\n"
-        "### Chain of Thought\nThink step by step about relevant legal provisions.\n"
-        "### Answer (JSON)\n{\"answer\": `Your answer`, \"sources\": []}"
-    )
+def make_prompt(q, docs):
+    context = "\n\n".join(f"Title: {doc['title']}\nPages: {doc['pages']}" for doc in docs)
+    return f"{SYS}\n\nContext:\n{context}\n\nQuestion:\n{q}\n\nAnswer:"

 @spaces.GPU()
 def qa_fn(question, top_k, temperature, max_tokens):
     docs, file_sources = retrieve(question, top_k)
-    prompt = make_prompt(question, docs)
-    outputs = pipe(
-        prompt,
-        max_new_tokens=max_tokens,
-        do_sample=False,
-        temperature=temperature,
-        top_p=0.9,
-        return_full_text=False
-    )
-    raw = outputs[0]["generated_text"].strip()
-    try:
-        json_start = raw.find('{')
-        json_text = raw[json_start: raw.rfind('}')+1]
-        import json; result = json.loads(json_text.replace('`', '"'))
-        answer = result.get("answer", raw)
-    except Exception:
-        answer = raw
-    return answer, file_sources
+    prompt = make_prompt(question, docs)[:8000]
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+    Thread(target=model.generate, kwargs={
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_tokens,
+        "do_sample": False,
+        "temperature": temperature,
+        "top_p": 0.9,
+        "eos_token_id": tokenizer.eos_token_id
+    }).start()
+    output = ""
+    for tok in streamer:
+        output += tok
+
+    think_tag_index = output.find("</think>")  # change to "Answer:" after testing
+    if think_tag_index != -1:
+        output = output[think_tag_index + len("</think>"):].strip()
+
+    return output, file_sources
+
+outputs_answer = gr.Textbox(label="Answer")
+outputs_sources = gr.JSON(label="Sources (Used Files)")

 demo = gr.Interface(
     fn=qa_fn,
     inputs=[
         gr.Textbox(lines=2, label="Your Question"),
-        gr.Slider(1, 15, value=5, step=1, label="Top-K Documents"),
-        gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature"),
-        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length")
-    ],
-    outputs=[
-        gr.Textbox(label="Answer"),
-        gr.JSON(label="Sources (Used Files)")
+        gr.Slider(1, 15, value=5, step=1, label="Top-K Documents (How many chunks to include for context)"),
+        gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature (Higher = more creative, lower = more focused)"),
+        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length (Maximum tokens to generate)")
     ],
+    outputs=[outputs_answer, outputs_sources],
     title="GDPR Legal Assistant",
-    description="Enhanced RAG with reranking, structured prompts & CoT for precise legal answers.",
+    description="Ask any question about GDPR or EDPB documents. The response includes used files and chunks.",
     allow_flagging="never"
 )
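
For reference, retrieve() on the new side is a plain FAISS nearest-neighbor lookup: the query is embedded with the BGE model, and index.search returns distances D and row indices I into the pickled chunk list. A minimal, self-contained sketch of that lookup, with random vectors standing in for the real embeddings (the actual index type inside vector_db/index.faiss is not visible in this diff, so an exact-L2 index is assumed here):

import faiss
import numpy as np

dim = 1024  # embedding size of BAAI/bge-large-en-v1.5
chunks = ["chunk one", "chunk two", "chunk three"]
embeddings = np.random.rand(len(chunks), dim).astype("float32")

index = faiss.IndexFlatL2(dim)  # exact L2 search; stand-in for vector_db/index.faiss
index.add(embeddings)

query_emb = np.random.rand(dim).astype("float32")
D, I = index.search(np.array([query_emb]), 2)  # distances and indices, each shape (1, k)
for i in I[0]:
    print(chunks[int(i)])  # the chunks retrieve() would hand to make_prompt()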
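Likewise, the reverted qa_fn streams generation through transformers' TextIteratorStreamer: model.generate runs on a worker thread and pushes decoded text into the streamer, which the main thread drains as an iterator. DeepSeek-R1 distill models wrap their chain of thought in <think>...</think> tags, which is why the function then discards everything up to </think>. A minimal sketch of the threading pattern, using gpt2 as a stand-in for the 4-bit DeepSeek model:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("GDPR Article 5 requires", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread while
# the main thread consumes decoded text from the streamer as it arrives.
thread = Thread(target=model.generate, kwargs={
    **inputs,
    "streamer": streamer,
    "max_new_tokens": 40,
    "do_sample": False,
})
thread.start()

output = ""
for piece in streamer:
    output += piece
thread.join()
print(output)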