arsiba commited on
Commit
ceb92f6
Β·
1 Parent(s): ff30c53

fix: move to gpu when running

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -8,18 +8,10 @@ from sentence_transformers import SentenceTransformer
8
  import gradio as gr
9
  from threading import Thread
10
 
11
- # FAISS Index auf GPU laden
12
  print("Loading FAISS index...")
13
  cpu_index = faiss.read_index("vector_db/index.faiss")
14
-
15
- # PrΓΌfen ob GPU verfΓΌgbar ist und Index auf GPU verschieben
16
- if torch.cuda.is_available():
17
- print("Moving FAISS to GPU...")
18
- res = faiss.StandardGpuResources()
19
- index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
20
- else:
21
- print("GPU not available, using CPU index")
22
- index = cpu_index
23
 
24
  with open("vector_db/chunks.pkl", "rb") as f:
25
  chunks = pickle.load(f)
@@ -55,8 +47,21 @@ SYS = (
55
 
56
  @spaces.GPU()
57
  def retrieve(q, k=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  emb = ST.encode(q)
59
- D, I = index.search(np.array([emb], dtype="float32"), k)
60
  docs, file_sources = [], []
61
  for i in I[0]:
62
  chunk = chunks[i]
 
8
  import gradio as gr
9
  from threading import Thread
10
 
11
+ # FAISS Index laden (vorerst auf CPU)
12
  print("Loading FAISS index...")
13
  cpu_index = faiss.read_index("vector_db/index.faiss")
14
+ gpu_index = None # GPU Index wird spΓ€ter erstellt
 
 
 
 
 
 
 
 
15
 
16
  with open("vector_db/chunks.pkl", "rb") as f:
17
  chunks = pickle.load(f)
 
47
 
48
  @spaces.GPU()
49
  def retrieve(q, k=3):
50
+ global gpu_index, cpu_index
51
+
52
+ # GPU Index beim ersten Aufruf erstellen
53
+ if gpu_index is None:
54
+ try:
55
+ print("Creating GPU index...")
56
+ res = faiss.StandardGpuResources()
57
+ gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
58
+ print("FAISS index successfully moved to GPU")
59
+ except Exception as e:
60
+ print(f"Failed to move FAISS to GPU: {e}")
61
+ gpu_index = cpu_index # Fallback to CPU
62
+
63
  emb = ST.encode(q)
64
+ D, I = gpu_index.search(np.array([emb], dtype="float32"), k)
65
  docs, file_sources = [], []
66
  for i in I[0]:
67
  chunk = chunks[i]