OrangeEye commited on
Commit
44aea13
·
1 Parent(s): 654e004
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # import os
2
  # # Set CUDA device dynamically
3
- # os.environ["CUDA_VISIBLE_DEVICES"] = "5"
4
 
5
  import spaces
6
  import torch
@@ -38,7 +38,8 @@ generate_kwargs = dict(
38
  # llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID = load_llama_guard("meta-llama/Llama-Guard-3-1B")
39
 
40
  ## RAG MODEL
41
- RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert", n_gpu=0)
 
42
 
43
  try:
44
  gr.Info("Setting up retriever, please wait...")
 
1
  # import os
2
  # # Set CUDA device dynamically
3
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ""
4
 
5
  import spaces
6
  import torch
 
38
  # llama_guard, llama_guard_tokenizer, UNSAFE_TOKEN_ID = load_llama_guard("meta-llama/Llama-Guard-3-1B")
39
 
40
  ## RAG MODEL
41
+ test_tensor = torch.tensor(2).to('cuda')
42
+ RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert", n_gpu=1)
43
 
44
  try:
45
  gr.Info("Setting up retriever, please wait...")