brendon-ai committed
Commit 9fb2feb · verified · 1 Parent(s): 41eefcc

Update src/RAGSample.py

Files changed (1)
  1. src/RAGSample.py +97 -45
src/RAGSample.py CHANGED
@@ -370,33 +370,62 @@ Answer:
         input_variables=["question", "documents"],
     )
 
-    tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT")
-    model = AutoModelForCausalLM.from_pretrained(
-        "microsoft/BioGPT",
-        device_map="auto",
-        torch_dtype=torch.float16
-    )
-
-    # Fix the tokenizer configuration
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # Initialize a local Hugging Face model
-    hf_pipeline = pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=100,  # Reduced for stability
-        max_length=1024,  # BioGPT's context length
-        temperature=0.2,  # Lower for more focused responses
-        device_map="auto",
-        torch_dtype=torch.float16,
-        return_full_text=False,
-        truncation=True,
-        do_sample=True,
-        pad_token_id=1,
-        eos_token_id=2
-    )
+    try:
+        tokenizer = AutoTokenizer.from_pretrained("microsoft/BioGPT")
+        model = AutoModelForCausalLM.from_pretrained(
+            "microsoft/BioGPT",
+            device_map="auto",
+            torch_dtype=torch.float16
+        )
+
+        # Fix the tokenizer configuration properly
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        print(f"Tokenizer pad_token_id: {tokenizer.pad_token_id}")
+        print(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
+
+        # Initialize pipeline with correct token IDs from tokenizer
+        hf_pipeline = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=50,  # Start small for testing
+            temperature=0.2,
+            return_full_text=False,
+            do_sample=True,
+            # Use actual tokenizer token IDs, not hardcoded values
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            clean_up_tokenization_spaces=True
+        )
+
+        # Test the pipeline with a simple input
+        test_input = "What is diabetes?"
+        print(f"Testing pipeline with: {test_input}")
+        test_result = hf_pipeline(test_input)
+        print(f"Pipeline test successful: {test_result}")
+
+    except Exception as e:
+        print(f"Error setting up BioGPT: {e}")
+        print("Falling back to DistilGPT-2...")
+
+        # Fallback to a more stable model
+        hf_pipeline = pipeline(
+            "text-generation",
+            model="distilgpt2",
+            max_new_tokens=50,
+            temperature=0.2,
+            return_full_text=False,
+            do_sample=True,
+            clean_up_tokenization_spaces=True
+        )
+
+        # Test the fallback pipeline
+        test_input = "What is diabetes?"
+        print(f"Testing fallback pipeline with: {test_input}")
+        test_result = hf_pipeline(test_input)
+        print(f"Fallback pipeline test successful: {test_result}")
 
     # Wrap it in LangChain
     llm = HuggingFacePipeline(pipeline=hf_pipeline)
@@ -404,7 +433,8 @@ Answer:
     # Create a chain combining the prompt template and LLM
     return prompt | llm | StrOutputParser()
 
-# Define the RAG application class
+
+# Also update the RAG application class with better error handling
 class RAGApplication:
     def __init__(self, retriever: BaseRetriever, rag_chain: Runnable):
         self.retriever = retriever
@@ -412,23 +442,45 @@ class RAGApplication:
 
     def run(self, question: str) -> str:
         """Runs the RAG pipeline for a given question."""
-        # Retrieve relevant documents
-        documents = self.retriever.invoke(question)
-
-        # Debug: Print retrieved documents
-        print(f"\nDEBUG: Retrieved {len(documents)} documents for question: '{question}'")
-        for i, doc in enumerate(documents):
-            print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
-
-        # Extract content from retrieved documents
-        doc_texts = "\n\n".join([doc.page_content for doc in documents])
-
-        # Debug: Print the combined document text
-        print(f"DEBUG: Combined document text: {doc_texts[:300]}...")
-
-        # Get the answer from the language model
-        answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
-        return answer
+        try:
+            # Input validation
+            if not question or not question.strip():
+                return "Please provide a valid question."
+
+            question = question.strip()
+            print(f"\nProcessing question: '{question}'")
+
+            # Retrieve relevant documents
+            documents = self.retriever.invoke(question)
+
+            # Debug: Print retrieved documents
+            print(f"DEBUG: Retrieved {len(documents)} documents")
+            for i, doc in enumerate(documents):
+                print(f"DEBUG: Document {i+1}: {doc.page_content[:200]}...")
+
+            # Extract content from retrieved documents
+            doc_texts = "\n\n".join([doc.page_content for doc in documents])
+
+            # Limit the total input length to prevent token overflow
+            max_input_length = 500  # Conservative limit
+            if len(doc_texts) > max_input_length:
+                doc_texts = doc_texts[:max_input_length] + "..."
+                print(f"DEBUG: Truncated document text to {max_input_length} characters")
+
+            print(f"DEBUG: Combined document text length: {len(doc_texts)}")
+
+            # Get the answer from the language model
+            print("DEBUG: Calling language model...")
+            answer = self.rag_chain.invoke({"question": question, "documents": doc_texts})
+            print(f"DEBUG: Language model response: {answer}")
+
+            return answer
+
+        except Exception as e:
+            print(f"Error in RAG application: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            return f"I apologize, but I encountered an error processing your question: {str(e)}. Please try rephrasing it or ask a different question."
 
 # Main execution block
 if __name__ == "__main__":
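
Note: the first hunk only shows the pipeline construction; the surrounding function then wraps the pipeline with HuggingFacePipeline and composes prompt | llm | StrOutputParser(), as its trailing context lines show. Below is a minimal, self-contained sketch of that wiring. It uses the distilgpt2 fallback model so it runs anywhere, and the template text and LangChain import paths are assumptions rather than code from RAGSample.py (older LangChain releases expose HuggingFacePipeline from langchain_community.llms instead of langchain_huggingface).

    # Hedged sketch of the prompt -> pipeline -> parser chain assumed by this diff.
    from transformers import pipeline
    from langchain_core.prompts import PromptTemplate
    from langchain_core.output_parsers import StrOutputParser
    from langchain_huggingface import HuggingFacePipeline

    # Small model so the sketch runs anywhere; the real code prefers microsoft/BioGPT.
    hf_pipeline = pipeline(
        "text-generation",
        model="distilgpt2",
        max_new_tokens=50,
        temperature=0.2,
        do_sample=True,
        return_full_text=False,
    )

    # Placeholder template; the real template lives earlier in RAGSample.py.
    prompt = PromptTemplate(
        template="Documents:\n{documents}\n\nQuestion: {question}\n\nAnswer:",
        input_variables=["question", "documents"],
    )

    # Wrap the Hugging Face pipeline for LangChain and compose the chain.
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    rag_chain = prompt | llm | StrOutputParser()

    print(rag_chain.invoke({
        "question": "What is diabetes?",
        "documents": "Diabetes is a chronic metabolic condition.",
    }))

Reading pad_token_id and eos_token_id from the tokenizer, as the new hunk does, avoids the hardcoded IDs 1 and 2 that the previous version passed, which need not match the tokenizer's actual special tokens.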
 
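A similarly hedged sketch of exercising the updated RAGApplication.run, assuming it runs inside RAGSample.py where RAGApplication is defined. StaticRetriever is a toy stand-in for the real vector-store retriever built elsewhere in the file, and create_rag_chain is an assumed name for the factory function whose body the first hunk edits.

    # Toy usage of the updated RAGApplication; StaticRetriever and create_rag_chain
    # are assumptions standing in for parts of RAGSample.py not shown in this diff.
    from typing import List
    from langchain_core.documents import Document
    from langchain_core.retrievers import BaseRetriever

    class StaticRetriever(BaseRetriever):
        """Returns a fixed set of documents in place of the real vector-store retriever."""
        docs: List[Document]

        def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
            return self.docs

    retriever = StaticRetriever(docs=[
        Document(page_content="Diabetes is a chronic condition affecting blood glucose regulation."),
    ])

    rag_chain = create_rag_chain()  # assumed factory returning prompt | llm | StrOutputParser()
    app = RAGApplication(retriever=retriever, rag_chain=rag_chain)

    print(app.run("What is diabetes?"))  # normal path
    print(app.run("   "))                # exercises the new input-validation branch

With the new error handling, a blank question returns a validation message, and any exception raised during retrieval or generation is caught, printed with a traceback, and surfaced as an apology message instead of crashing the application.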