arsiba committed on
Commit baadb7f · 1 Parent(s): b6247d6

feat: restart and build from the ground up

Files changed (3)
  1. .gradio/certificate.pem +31 -0
  2. app.py +62 -100
  3. requirements.txt +2 -3
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
app.py CHANGED
@@ -1,113 +1,75 @@
- # -*- coding: utf-8 -*-
  import spaces
- import os, logging, traceback, pickle, gc
- import gradio as gr
- import torch
- import faiss
  import numpy as np
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer
- from langchain.text_splitter import RecursiveCharacterTextSplitter
-
- VECTOR_DB_DIR = "vector_db"
- EMBEDDING_MODEL = "BAAI/bge-large-en-v1.5"
- GEN_QA_MODEL = "Qwen/Qwen2-7B-Instruct"
- MODEL_CONTEXT_SIZE = 10000
- CHUNK_SIZE = 512
- PROMPT_RESERVE = 1024

- index = None
- document_metadata = []
- all_chunks = []
- emb_model = None
- gen_tokenizer = None
- gen_model = None

- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

- def calculate_top_k():
-     return (MODEL_CONTEXT_SIZE - PROMPT_RESERVE) // CHUNK_SIZE

- def initialize_models():
-     global emb_model, gen_tokenizer, gen_model
-     emb_model = SentenceTransformer(EMBEDDING_MODEL, device="cpu")
-     gen_tokenizer = AutoTokenizer.from_pretrained(
-         GEN_QA_MODEL,
-         use_fast=False,
-         trust_remote_code=True
-     )
-     gen_tokenizer.pad_token = gen_tokenizer.eos_token
-     gen_model = AutoModelForCausalLM.from_pretrained(
-         GEN_QA_MODEL,
-         trust_remote_code=True,
-         device_map="cpu",
-         torch_dtype=torch.float16,
-         load_in_4bit=True,
-         low_cpu_mem_usage=True
-     )
-     gen_model.eval()
-     return "Models loaded."

- @spaces.GPU(duration=120)
- def load_faiss_database(progress=gr.Progress()):
-     global index, document_metadata, all_chunks
-     progress(0, "Reading FAISS index...")
-     idx_path = os.path.join(VECTOR_DB_DIR, "index.faiss")
-     if not os.path.exists(idx_path):
-         idx_path = os.path.join(VECTOR_DB_DIR, "faiss_index.idx")
-     cpu_index = faiss.read_index(idx_path)
-     progress(30, "Moving FAISS index to GPU...")
-     res = faiss.StandardGpuResources()
-     index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
-     progress(60, "Loading chunks & metadata...")
-     with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "rb") as f:
-         all_chunks = pickle.load(f)
-     with open(os.path.join(VECTOR_DB_DIR, "metadata.pkl"), "rb") as f:
-         document_metadata = pickle.load(f)
-     progress(100, "FAISS DB ready.")
-     return f"FAISS DB: {len(all_chunks)} chunks."

- @spaces.GPU(duration=120)
- def generate_answer(question, db_loaded):
-     global emb_model, gen_model
-     if not db_loaded:
-         return "Please initialize FAISS DB first."
-     emb_model.to("cuda")
-     gen_model.to("cuda")
-     torch.cuda.empty_cache()
-     gc.collect()
-     q_emb = emb_model.encode([question], convert_to_numpy=True)
-     dists, ids = index.search(q_emb.astype(np.float32), calculate_top_k())
-     ctx, sources = [], set()
-     for i in ids[0]:
-         m = document_metadata[i]
-         info = f"{m['source']} (p{m['page']})"
-         ctx.append(f"{info}: {all_chunks[i]}")
-         sources.add(info)
-     docs = splitter.split_text("\n\n".join(ctx))
-     full_context = "\n\n".join(docs)
-     messages = [
-         {"role":"system","content":"You are a GDPR/EDPB expert."},
-         {"role":"user","content":f"Context:\n{full_context}\n\nQ: {question}"}
-     ]
-     prompt = gen_tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
-     inputs = gen_tokenizer(prompt, return_tensors="pt", padding=True).to("cuda")
-     out = gen_model.generate(**inputs, max_new_tokens=PROMPT_RESERVE, do_sample=True)
-     text = gen_tokenizer.decode(out[0], skip_special_tokens=True).split("Assistant:")[-1].strip()
-     return f"Answer:\n{text}\n\nSources:\n- " + "\n- ".join(sources)

- with gr.Blocks(theme=gr.themes.Soft(), title="GDPR/EDPB Assistant") as demo:
-     status = gr.Textbox(label="Status", interactive=False)
-     init_btn = gr.Button("Initialize FAISS DB")
-     db_loaded = gr.State(False)
-     question = gr.Textbox(label="Legal Question", lines=3)
-     submit_btn = gr.Button("Submit Question")
-     answer = gr.Textbox(label="Answer", lines=12, interactive=False)

-     demo.load(initialize_models, outputs=status)
-     init_btn.click(load_faiss_database, outputs=status)
-     init_btn.click(lambda: True, outputs=db_loaded)
-     submit_btn.click(generate_answer, inputs=[question, db_loaded], outputs=answer)

  if __name__ == "__main__":
-     demo.launch()

  import spaces
+ import pickle
  import numpy as np
+ import faiss
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
  from sentence_transformers import SentenceTransformer
+ import gradio as gr
+ from threading import Thread

+ index = faiss.read_index("vector_db/index.faiss")
+ with open("vector_db/chunks.pkl", "rb") as f:
+     chunks = pickle.load(f)

+ ST = SentenceTransformer("BAAI/bge-large-en-v1.5")

+ model_id = "Qwen/Qwen2.5-7B-Instruct"
+ bnb = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     quantization_config=bnb,
+     device_map={"": 0},
+     torch_dtype=torch.bfloat16
+ )

+ SYS = "You are a specialized assistant for answering questions related to legal texts from the GDPR (General Data Protection Regulation) and several Documents of the EDPB (European Data Protection Board). " \
+       "Your task is to provide precise and detailed answers based on the provided excerpts from the documents. " \
+       "Ensure that you clearly and understandably explain the relevant legal concepts. If you do not know the answer or if the information is insufficient, respond with: 'I do not know.' " \
+       "Avoid giving inaccurate or speculative answers."

+ def retrieve(q, k=3):
+     emb = ST.encode(q)
+     D, I = index.search(np.array([emb], dtype="float32"), k)
+     return [chunks[i] for i in I[0]]

+ def make_prompt(q, docs):
+     return SYS + "\n\nContext:\n" + "\n".join(docs) + f"\n\nQuestion: {q}\nAnswer:"

+ @spaces.GPU
+ def qa_fn(question: str) -> str:
+     docs = retrieve(question, 10)
+     prompt = make_prompt(question, docs)[:8000]
+     inputs = tokenizer(prompt, return_tensors="pt")
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
+     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+     Thread(target=model.generate, kwargs={
+         **inputs,
+         "streamer": streamer,
+         "max_new_tokens": 512,
+         "do_sample": True,
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "eos_token_id": tokenizer.eos_token_id
+     }).start()
+     out = ""
+     for tok in streamer:
+         out += tok
+     return out

+ demo = gr.Interface(
+     fn=qa_fn,
+     inputs=gr.Textbox(lines=2, label="Your question"),
+     outputs=gr.Textbox(lines=10, label="Answer"),
+     title="GDPR QA (RAG)",
+     description="Ask questions on GDPR; answers are grounded in EDPB document chunks."
+ )

  if __name__ == "__main__":
+     demo.launch(share=True)
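The new app.py expects a prebuilt vector store (vector_db/index.faiss plus the matching vector_db/chunks.pkl), but the commit does not include the script that produces it. Below is a minimal sketch of how such an index could be built with the same embedding model; the file name build_index.py, the sample chunks, and the IndexFlatL2 choice are illustrative assumptions, not part of this repository.

# build_index.py - hypothetical offline builder, not part of this commit.
import os
import pickle

import faiss
from sentence_transformers import SentenceTransformer

# Same model that app.py uses to embed queries, so the vector spaces match.
st = SentenceTransformer("BAAI/bge-large-en-v1.5")

# Placeholder chunks; the real ones would come from the GDPR/EDPB documents.
chunks = [
    "Art. 5 GDPR: Personal data shall be processed lawfully, fairly and transparently...",
    "EDPB Guidelines 07/2020 on the concepts of controller and processor in the GDPR...",
]

emb = st.encode(chunks, convert_to_numpy=True).astype("float32")
index = faiss.IndexFlatL2(emb.shape[1])  # exact L2 search over the chunk embeddings
index.add(emb)

os.makedirs("vector_db", exist_ok=True)
faiss.write_index(index, "vector_db/index.faiss")
with open("vector_db/chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)  # list order must line up with the FAISS ids

BGE embeddings are often L2-normalized and searched with inner product instead; whichever metric the real build used, retrieve() only relies on index.search() returning the ids of the nearest chunks.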
requirements.txt CHANGED
@@ -1,10 +1,9 @@
- spaces
  torch
  transformers
  sentence-transformers
- langchain
  faiss-gpu
  gradio
  numpy<2
- accelerate
  bitsandbytes

  torch
  transformers
  sentence-transformers
  faiss-gpu
  gradio
  numpy<2
  bitsandbytes
+ accelerate
+ spaces
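Since demo.launch(share=True) exposes the Interface over a public URL, the app can also be queried programmatically once it is running. A minimal sketch using gradio_client; the Space id "arsiba/gdpr-qa" is a placeholder, as the actual Space name does not appear in the diff.

# query_space.py - hypothetical client; "arsiba/gdpr-qa" is a placeholder id.
from gradio_client import Client

# A share URL printed by demo.launch(share=True) works here as well.
client = Client("arsiba/gdpr-qa")
answer = client.predict(
    "When is a data protection impact assessment required under Art. 35 GDPR?",
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(answer)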