import os
import re
from collections import Counter

import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp import Llama

# Configuration: model repository and file are overridable via environment variables.
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")

# Download (or locate in the local cache) the quantized LLaMA model.
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)

# Initialize LLaMA via llama_cpp. Set LLAMA_GPU_LAYERS=0 to run on CPU only.
llm = Llama(
    model_path=model_path,
    n_threads=int(os.getenv("LLAMA_THREADS", "4")),
    n_batch=int(os.getenv("LLAMA_BATCH", "256")),
    n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")),
    n_ctx=int(os.getenv("LLAMA_CTX", "8192")),
)

# Prompt templates from the notebook. Each strategy is instructed to end with a
# bracketed list so that extract_list() below can parse the diagnoses.
system_prompt_few_shot = """
SYSTEM: You are an AI medical assistant specializing in differential diagnosis.
Generate the most likely list of diagnoses based on the examples.
USER: A 45-year-old male, fever, cough, fatigue.
SYSTEM: [Flu, COVID-19, Pneumonia]
USER: A 30-year-old female, severe abdominal pain, nausea.
SYSTEM: [Appendicitis, Gallstones, Gastritis]
USER: A 10-year-old female, wheezing.
SYSTEM: [Asthma, Respiratory Infection]
USER: """

system_prompt_cot = """
SYSTEM: You are a medical expert performing differential diagnosis through
step-by-step reasoning. Show your intermediate reasoning, then give the final
diagnoses as a bracketed list, e.g. [Diagnosis1, Diagnosis2].
USER: """

system_prompt_tot = """
SYSTEM: You are a medical expert using a tree-of-thought approach for
differential diagnosis. Construct a reasoning tree, then give the final
diagnoses as a bracketed list, e.g. [Diagnosis1, Diagnosis2].
USER: """


def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=("USER",)):
    """Run a completion; "USER" is a stop sequence so the model does not
    hallucinate further dialogue turns."""
    return llm(prompt=prompt, max_tokens=max_tokens,
               temperature=temperature, stop=list(stop))


def extract_list(text):
    """Return the items of the first bracketed list in `text`, or []."""
    match = re.search(r"\[(.*?)\]", text)
    if match:
        return [item.strip() for item in match.group(1).split(",") if item.strip()]
    return []


def determine_most_probable(few_shot, cot, tot):
    """Majority vote across the three strategies (case-sensitive)."""
    counts = Counter(few_shot + cot + tot)
    if not counts:
        return "No Clear Diagnosis"
    return counts.most_common(1)[0][0]


def medical_diagnosis(symptoms: str):
    try:
        # Generate one response per prompting strategy.
        resp_few = lcpp_llm(system_prompt_few_shot + symptoms)
        resp_cot = lcpp_llm(system_prompt_cot + symptoms)
        resp_tot = lcpp_llm(system_prompt_tot + symptoms)

        # Extract the completion text.
        text_few = resp_few["choices"][0]["text"].strip()
        text_cot = resp_cot["choices"][0]["text"].strip()
        text_tot = resp_tot["choices"][0]["text"].strip()

        # Parse the bracketed diagnosis lists.
        few = extract_list(text_few)
        cot = extract_list(text_cot)
        tot = extract_list(text_tot)
        most = determine_most_probable(few, cot, tot)

        # Format the Markdown output.
        return f"""
### Differential Diagnosis Results

**Few-Shot Diagnoses:** {', '.join(few) if few else 'No Diagnosis'}

**Chain-of-Thought Diagnoses:** {', '.join(cot) if cot else 'No Diagnosis'}

**Tree-of-Thought Diagnoses:** {', '.join(tot) if tot else 'No Diagnosis'}

**Most Probable Diagnosis:** {most}
"""
    except Exception as e:
        return f"Error: {e}"


# Gradio app definition
with gr.Blocks() as demo:
    gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
    cond = gr.Textbox(
        label="Patient Condition",
        placeholder="A 35-year-old male, fever, wheezing, nausea.",
    )
    out = gr.Markdown()
    btn = gr.Button("Diagnose")
    btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
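
# Hedged usage notes (not part of the app logic). The sample strings below are
# hypothetical model outputs for illustrating the parsing helpers, not captured
# responses, and the environment values are illustrative, not tested:
#
#   >>> extract_list("Step 1: consider airway causes... Final: [Asthma, Bronchitis, GERD]")
#   ['Asthma', 'Bronchitis', 'GERD']
#   >>> determine_most_probable(["Asthma"], ["Asthma", "GERD"], ["Bronchitis"])
#   'Asthma'
#
# Example environment for a CPU-only run with a smaller quantization (the file
# name is an assumption about what the HF repo provides):
#
#   export HF_MODEL_BASENAME="Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
#   export LLAMA_GPU_LAYERS=0
#   export LLAMA_THREADS=8
#   python app.py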