Update src/RAGSample.py
src/RAGSample.py  +24 −12
@@ -347,6 +347,29 @@ def setup_retriever(use_kaggle_data: bool = False, kaggle_dataset: Optional[str]
 # # Create a chain combining the prompt template and LLM
 # return prompt | llm | StrOutputParser()
 
+def initialize_biogpt():
+    try:
+        hf_pipeline = pipeline(
+            "text-generation",
+            model="microsoft/BioGPT",
+            tokenizer="microsoft/BioGPT",
+            max_new_tokens=150,
+            temperature=0.3,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            return_full_text=False,
+            truncation=True,
+            do_sample=True,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            pad_token_id=1,
+            eos_token_id=2,
+        )
+        print("BioGPT loaded successfully!")
+        return hf_pipeline
+    except Exception as e:
+        print(f"Error loading BioGPT: {e}")
+        return None
 
 def setup_rag_chain() -> Runnable:
     """Sets up the RAG chain with a prompt template and an LLM."""
@@ -369,18 +392,7 @@ Answer:
     )
 
     # Initialize a local Hugging Face model
-    hf_pipeline = pipeline(
-        "text-generation",
-        model="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",  # Excellent for Q&A tasks
-        tokenizer="deepseek-ai/DeepSeek-R1-0528-Qwen3-8B",
-        max_new_tokens=150,  # Generate only 150 new tokens instead of max_length
-        temperature=0.1,
-        device=0 if torch.cuda.is_available() else -1,
-        return_full_text=False,
-        truncation=True,  # Truncate input if too long
-        do_sample=True,  # Enable sampling for better responses
-        pad_token_id=50256  # Add padding token to avoid warnings
-    )
+    hf_pipeline = initialize_biogpt()
 
     # Wrap it in LangChain
     llm = HuggingFacePipeline(pipeline=hf_pipeline)
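A note on reading this change, separate from the commit itself: initialize_biogpt() returns None when the model fails to load, yet setup_rag_chain() passes its result straight into HuggingFacePipeline, which would surface the failure later as a less obvious error. The sketch below shows one way the call site could guard against that case. It is an illustration only, assuming a fail-fast policy is acceptable, that pipeline and torch are imported elsewhere in RAGSample.py as the hunk implies, and that HuggingFacePipeline is imported from the langchain_huggingface package (the real import path is not shown in this diff).

# Illustrative guard, not part of this commit.
# Import path for HuggingFacePipeline is an assumption about this repo's setup.
from langchain_huggingface import HuggingFacePipeline

hf_pipeline = initialize_biogpt()
if hf_pipeline is None:
    # Fail fast rather than handing None to the LangChain wrapper,
    # which would otherwise break only when the chain is first invoked.
    raise RuntimeError("BioGPT could not be loaded; cannot build the RAG chain.")
llm = HuggingFacePipeline(pipeline=hf_pipeline)

Two related details worth checking in the Space's environment: device_map="auto" in the new loader requires the accelerate package to be installed, and torch_dtype=torch.float16 assumes hardware with half-precision support, so a CPU-only Space may need to fall back to float32.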