# Spaces: Running on Zero
# Kept first so the ZeroGPU `spaces` package is imported before torch.
import spaces

import pickle
from threading import Thread
from urllib.parse import quote

import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
from sentence_transformers import SentenceTransformer
import gradio as gr
# --- Vector store -------------------------------------------------------------
# FAISS index plus the chunk texts and per-chunk metadata. Both pickles are
# looked up by the FAISS result id. pickle can execute arbitrary code, so only
# ever load trusted files from `vector_db/` here.


def _unpickle(path):
    """Open *path* in binary mode and return the object pickled in it."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


index = faiss.read_index("vector_db/index.faiss")
chunks = _unpickle("vector_db/chunks.pkl")
metadata_dict = _unpickle("vector_db/metadata.pkl")

# Embedding model used to encode search queries before searching `index`.
ST = SentenceTransformer("BAAI/bge-large-en-v1.5")

# URL prefix under which source documents are linked.
github_base_url = "https://github.com/arsiba/EDPB-AI/blob/main/"
# --- Generator LLM ------------------------------------------------------------
model_id = "HuggingFaceH4/zephyr-7b-beta"

# 4-bit NF4 weights with double quantisation; compute happens in bfloat16.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# zephyr-7b-beta is a Mistral-architecture model supported natively by
# transformers, so no remote code needs to run. Only pass
# `trust_remote_code=True` for a model whose repo ships custom code you have
# reviewed.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS id for padding.
tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb,
    device_map={"": 0},  # "" maps the whole model onto GPU 0
    torch_dtype=torch.bfloat16,
)
# System instruction placed before every question. The pieces are joined by
# implicit string concatenation, so each one must end with a space; otherwise
# the sentences run together.
SYS = (
    "You are a legal AI assistant specialized in GDPR/EDPB. "
    "If you cannot find an answer in the context, it's okay to speculate. But if so, make it clear. "
    "Answer this Question:"
)
def retrieve(q, k=3):
    """Return the chunks nearest to query text ``q`` in the FAISS index.

    ``q`` is embedded with ``ST``, ``index`` is searched for ``k`` neighbours,
    and each hit id is looked up in ``chunks`` and ``metadata_dict``.

    Args:
        q: Search text.
        k: Number of neighbours to request. It is coerced to ``int`` because UI
            sliders may deliver a float. Values <= 0 return no results.

    Returns:
        ``(docs, file_sources)``. ``docs`` is a list of
        ``{"title": metadata, "pages": chunk}`` dicts and ``file_sources`` is
        the matching list of metadata entries, both in the order FAISS returned
        them. Either list may be shorter than ``k`` when the index has fewer
        vectors than requested.
    """
    k = int(k)
    if k <= 0:
        return [], []
    query = np.array([ST.encode(q)], dtype="float32")
    _, ids = index.search(query, k)

    docs, file_sources = [], []
    for raw_id in ids[0]:
        i = int(raw_id)
        # FAISS pads missing results with -1. Indexing with -1 would silently
        # return the *last* chunk instead of nothing.
        if i < 0:
            continue
        meta = metadata_dict[i]
        docs.append({"title": meta, "pages": chunks[i]})
        file_sources.append(meta)
    return docs, file_sources
def make_prompt(q, docs):
    """Build the plain-text ``Instruct: ... Output:`` prompt for question ``q``.

    Each entry of ``docs`` is rendered as a ``Title:`` line (its ``"title"``
    value, formatted as-is) followed by a ``Pages:`` line (its ``"pages"``
    text). Entries are separated by blank lines. ``SYS`` precedes the question,
    and the prompt ends with ``Output:`` so the model's continuation is the
    answer.
    """
    context = "\n\n".join(f"Title: {d['title']}\nPages: {d['pages']}" for d in docs)
    return f"Instruct: {SYS} {q} based on the following documents:\n{context}\nOutput:"
def build_markdown_links(file_input, base_url=None):
    """Render one numbered Markdown link per source document.

    Args:
        file_input: Iterable of metadata dicts with ``"directory"``,
            ``"source"`` and ``"page"`` keys.
        base_url: URL prefix for repository files. Defaults to the module-level
            ``github_base_url``.

    Returns:
        Lines of the form ``**Source N:** [source](url) on page P`` joined by
        blank lines (``""`` for empty input).
    """
    if base_url is None:
        base_url = github_base_url
    # Drop the trailing slash so joining with "/" cannot yield "main//dir".
    base_url = base_url.rstrip("/")

    lines = []
    for idx, item in enumerate(file_input, start=1):
        directory = str(item["directory"]).strip("/")
        rel_path = "/".join(part for part in (directory, str(item["source"])) if part)
        # Percent-encode the path (keeping "/") so spaces etc. don't break the link.
        url = f"{base_url}/{quote(rel_path)}"
        lines.append(f"**Source {idx}:** [{item['source']}]({url}) on page {item['page']}")
    return "\n\n".join(lines)
def build_markdown_chunks(docs):
    """Render retrieved chunks as numbered Markdown sections.

    Each doc carries its metadata dict under ``"title"`` (read for its
    ``"source"`` and ``"page"`` keys) and its chunk text under ``"pages"``.
    Sections are separated by a blank line; empty input yields ``""``.
    """
    def _section(number, doc):
        meta = doc["title"]
        return f"**Chunk {number}:** {meta['source']} on page {meta['page']}\n\n{doc['pages']}"

    return "\n\n".join(_section(n, doc) for n, doc in enumerate(docs, start=1))
def _fit_prompt(question, docs, max_chars=8000):
    """Build a prompt of at most ``max_chars`` characters by trimming document text.

    Text is removed from the last (lowest-ranked) documents first. This keeps
    the instruction header and the trailing ``Output:`` cue, which slicing the
    finished prompt would cut off. If the prompt is still too long with all
    document text removed, it falls back to a hard slice.
    """
    prompt = make_prompt(question, docs)
    overflow = len(prompt) - max_chars
    if overflow <= 0:
        return prompt
    trimmed = [dict(d) for d in docs]
    for d in reversed(trimmed):
        text = str(d["pages"])
        cut = min(overflow, len(text))
        d["pages"] = text[: len(text) - cut]
        overflow -= cut
        if overflow <= 0:
            break
    return make_prompt(question, trimmed)[:max_chars]


def qa_fn(faiss_search, question, top_k, temperature, max_tokens):
    """Retrieve context for ``faiss_search`` and answer ``question`` with the LLM.

    Args:
        faiss_search: Text used for the vector search. Falls back to
            ``question`` when blank.
        question: The question put to the model.
        top_k: Number of chunks to retrieve.
        temperature: Sampling temperature. Values <= 0 (or None) select greedy
            decoding.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        Six values in Gradio output order: answer heading, answer text,
        documents heading, source links (Markdown), context heading, and the
        used chunks (Markdown).

    NOTE(review): ``spaces`` is imported by this file but this GPU entry point
    is not decorated with ``@spaces.GPU``. Confirm whether ZeroGPU hardware
    needs it.
    """
    search_text = (faiss_search or "").strip() or question
    docs, file_sources = retrieve(search_text, top_k)
    file_links = build_markdown_links(file_sources)
    markdown_chunks = build_markdown_chunks(docs)

    prompt = _fit_prompt(question, docs)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": int(max_tokens),
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature and temperature > 0:
        # Sample so the Temperature slider has an effect. With
        # do_sample=False, transformers ignores temperature/top_p.
        gen_kwargs.update(do_sample=True, temperature=float(temperature), top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False

    # Generate synchronously: the answer is returned in one piece anyway, and
    # an exception from generate() now propagates instead of leaving a
    # streamer consumer blocked forever.
    output_ids = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens, so prompt text (which may itself
    # contain "Output:") can never leak into the answer.
    prompt_len = inputs["input_ids"].shape[1]
    output = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True).strip()

    return (
        "\n# Generated Answer\n",
        output,
        "\n# Used Documents\n",
        file_links,
        "\n# Used Context\n",
        markdown_chunks,
    )
# --- Gradio UI ----------------------------------------------------------------
heading_answer = gr.Markdown(label="Answer Heading")
outputs_answer = gr.Textbox(label="Answer")
heading_links = gr.Markdown(label="Links Heading")
heading_chunks = gr.Markdown(label="Chunks Heading")
outputs_link = gr.Markdown(label="Source Link")
outputs_chunks = gr.Markdown(label="Used Chunks")

demo = gr.Interface(
    fn=qa_fn,
    # Main inputs map positionally to qa_fn(faiss_search, question, ...).
    inputs=[
        gr.Textbox(
            lines=4,
            label="What Documents are you looking for?",
            placeholder=(
                "Please change to get proper results:\n"
                "Documents covering the EDPB’s stance on automated decision-making, "
                "particularly profiling, under the GDPR. Guidelines on how organizations "
                "should inform data subjects about automated decisions and the rights of "
                "individuals to object to such decisions."
            ),
        ),
        gr.Textbox(
            lines=1,
            label="What is your question?",
            placeholder=(
                "Please change to get proper results:\n"
                "What does the EDPB recommend regarding automated decision-making and "
                "profiling under the GDPR, and what rights do individuals have in relation "
                "to such decisions?"
            ),
        ),
    ],
    # Passed after the main inputs: top_k, temperature, max_tokens.
    additional_inputs=[
        gr.Slider(1, 10, value=7, step=1, label="Top-K Documents"),
        gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Temperature"),
        gr.Slider(64, 1024, value=512, step=64, label="Max Answer Length"),
    ],
    additional_inputs_accordion="Advanced Options",
    # Order must match the six-item tuple returned by qa_fn.
    outputs=[
        heading_answer,
        outputs_answer,
        heading_links,
        outputs_link,
        heading_chunks,
        outputs_chunks,
    ],
    title="GDPR Legal Assistant",
    description="Ask any question about GDPR or EDPB documents.",
    allow_flagging="never",
    fill_width=True,
)

if __name__ == "__main__":
    demo.launch(share=True)