arsiba committed on
Commit
099d11b
Β·
1 Parent(s): f6cf353

feat: further optimize for gpu speed

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -8,13 +8,18 @@ from sentence_transformers import SentenceTransformer
8
  import gradio as gr
9
  from threading import Thread
10
 
11
- index = faiss.read_index("vector_db/index.faiss")
 
 
 
12
  with open("vector_db/chunks.pkl", "rb") as f:
13
  chunks = pickle.load(f)
14
  with open("vector_db/metadata.pkl", "rb") as f:
15
  metadata_dict = pickle.load(f)
16
 
17
- ST = SentenceTransformer("BAAI/bge-large-en-v1.5")
 
 
18
 
19
  model_id = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"
20
  bnb = BitsAndBytesConfig(
@@ -39,18 +44,20 @@ SYS = (
39
  "Answer this Question:"
40
  )
41
 
 
42
  def retrieve(q, k=3):
43
- emb = ST.encode(q)
44
- D, I = index.search(np.array([emb], dtype="float32"), k)
 
 
45
  docs, file_sources = [], []
46
- for i in I[0]:
47
  chunk = chunks[i]
48
  meta = metadata_dict[i]
49
  docs.append({"title": chunk, "pages": chunk})
50
  file_sources.append(meta["source"])
51
  return docs, file_sources
52
 
53
-
54
  def make_prompt(q, docs, reasoning_mode):
55
  context = "\n\n".join(f"Title: {d['title']}\nPages: {d['pages']}" for d in docs)
56
  prompt = f"detailed thinking {reasoning_mode}\n"
@@ -59,7 +66,7 @@ def make_prompt(q, docs, reasoning_mode):
59
  prompt += f"Instruct: {SYS} {q} based on the following documents:\n{context}\nOutput:"
60
  return prompt
61
 
62
- @spaces.GPU()
63
  def qa_fn(question, reasoning_mode, top_k, temperature, max_tokens):
64
  docs, file_sources = retrieve(question, top_k)
65
  prompt = make_prompt(question, docs, reasoning_mode)[:8000]
 
8
  import gradio as gr
9
  from threading import Thread
10
 
11
# --- Vector index -----------------------------------------------------------
# Load the FAISS index from disk and copy it onto GPU 0 so similarity search
# runs on the GPU.
# NOTE(review): on HF ZeroGPU Spaces, CUDA is typically unavailable at module
# import time (outside @spaces.GPU functions) — confirm this Space has a
# persistent GPU, otherwise this transfer must move inside a GPU function.
cpu_index = faiss.read_index("vector_db/index.faiss")
res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, 0, cpu_index)

# --- Chunk texts and per-chunk metadata (parallel to the index ids) ---------
with open("vector_db/chunks.pkl", "rb") as f:
    chunks = pickle.load(f)
with open("vector_db/metadata.pkl", "rb") as f:
    metadata_dict = pickle.load(f)

# --- Embedding model --------------------------------------------------------
# device="cuda" already places the model on the GPU, so the previous
# redundant ST.cuda() call has been dropped.
ST = SentenceTransformer("BAAI/bge-large-en-v1.5", device="cuda")
ST.compile()  # torch.compile the forward pass to speed up repeated encodes
23
 
24
  model_id = "nvidia/Llama-3.1-Nemotron-Nano-8B-v1"
25
  bnb = BitsAndBytesConfig(
 
44
  "Answer this Question:"
45
  )
46
 
47
@spaces.GPU(duration=20)
def retrieve(q, k=3):
    """Embed query *q* and return the top-*k* matching chunks.

    Args:
        q: the query string to embed.
        k: number of nearest neighbours to fetch from the index.

    Returns:
        (docs, file_sources): ``docs`` is a list of ``{"title", "pages"}``
        dicts (both keys currently hold the raw chunk text) and
        ``file_sources`` holds the matching metadata ``"source"`` values.
    """
    # Encode on GPU, then hand FAISS a float32 numpy batch of shape (1, dim).
    # Fix: the plain faiss Python API does not accept torch tensors in
    # index.search (that needs faiss.contrib.torch_utils, which is not
    # imported here), and its numpy return value has no .cpu() method.
    emb = ST.encode(q, convert_to_tensor=True)
    query = emb.unsqueeze(0).float().cpu().numpy()
    _dists, neighbor_ids = index.search(query, k)
    docs, file_sources = [], []
    for i in neighbor_ids[0].tolist():
        chunk = chunks[i]
        meta = metadata_dict[i]
        # NOTE(review): "title" and "pages" both carry the chunk text —
        # looks intentional for prompt formatting, but confirm.
        docs.append({"title": chunk, "pages": chunk})
        file_sources.append(meta["source"])
    return docs, file_sources
60
 
 
61
  def make_prompt(q, docs, reasoning_mode):
62
  context = "\n\n".join(f"Title: {d['title']}\nPages: {d['pages']}" for d in docs)
63
  prompt = f"detailed thinking {reasoning_mode}\n"
 
66
  prompt += f"Instruct: {SYS} {q} based on the following documents:\n{context}\nOutput:"
67
  return prompt
68
 
69
+ @spaces.GPU(duration=20)
70
  def qa_fn(question, reasoning_mode, top_k, temperature, max_tokens):
71
  docs, file_sources = retrieve(question, top_k)
72
  prompt = make_prompt(question, docs, reasoning_mode)[:8000]