import os
import gradio as gr
import re
from collections import Counter
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
# Configuration: model repository and filename, overridable via environment variables
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")
# Download or locate the quantized LLaMA model
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)
# Initialize LLaMA via llama_cpp
llm = Llama(
    model_path=model_path,
    n_threads=int(os.getenv("LLAMA_THREADS", "4")),        # CPU threads
    n_batch=int(os.getenv("LLAMA_BATCH", "256")),          # prompt-processing batch size
    n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")), # layers offloaded to GPU
    n_ctx=int(os.getenv("LLAMA_CTX", "8192")),             # context window in tokens
)
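# Note: n_gpu_layers has an effect only when llama-cpp-python is built with GPU
# support; on a CPU-only build the model runs entirely on CPU regardless of the
# value. Set LLAMA_GPU_LAYERS=0 to make CPU-only execution explicit.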
# Prompt templates (ported from the original notebook)
system_prompt_few_shot = """
SYSTEM:
You are an AI medical assistant specializing in differential diagnosis.
Given a patient description, return the most likely diagnoses as a bracketed list, following the examples.
USER: A 45-year-old male, fever, cough, fatigue.
SYSTEM: [Flu, COVID-19, Pneumonia]
USER: A 30-year-old female, severe abdominal pain, nausea.
SYSTEM: [Appendicitis, Gallstones, Gastritis]
USER: A 10-year-old female, wheezing.
SYSTEM: [Asthma, Respiratory Infection]
USER:
"""
system_prompt_cot = """
SYSTEM:
You are a medical expert performing differential diagnosis through step-by-step reasoning.
Walk through your intermediate reasoning, then give the final diagnoses as a bracketed list, e.g. [Flu, Pneumonia].
USER:
"""
system_prompt_tot = """
SYSTEM:
You are a medical expert using a tree-of-thought approach for differential diagnosis.
Construct a reasoning tree, then give the final diagnoses as a bracketed list, e.g. [Asthma, Bronchitis].
USER:
"""
def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=("USER",)):
    """Run one completion against the local model, stopping at the next USER turn."""
    return llm(prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=list(stop))
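# The call above returns an OpenAI-style completion dict, so the generated text
# is read as result['choices'][0]['text'] -- the same path medical_diagnosis()
# uses below. For example:
#   result = lcpp_llm(system_prompt_few_shot + "A 25-year-old male, sore throat.")
#   text = result['choices'][0]['text']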
def extract_list(text):
    """Extract the first bracketed list in a completion, e.g. "[Flu, Pneumonia]"."""
    match = re.search(r'\[(.*?)\]', text)
    if match:
        return [item.strip() for item in match.group(1).split(",") if item.strip()]
    return []
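# Example: extract_list("Final diagnoses: [Flu, COVID-19, Pneumonia]") returns
# ['Flu', 'COVID-19', 'Pneumonia']; text with no bracketed list returns [].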
def determine_most_probable(few_shot, cot, tot):
    """Return the diagnosis named most often across the three strategies."""
    counts = Counter(few_shot + cot + tot)
    if not counts:
        return "No Clear Diagnosis"
    return counts.most_common(1)[0][0]
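# Ties are broken by first appearance: Counter preserves insertion order and
# most_common() sorts stably by count, so on equal counts the diagnosis seen
# first (few-shot before CoT before ToT) wins.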
def medical_diagnosis(symptoms: str):
    try:
        # Generate a response for each prompting strategy
        resp_few = lcpp_llm(system_prompt_few_shot + symptoms)
        resp_cot = lcpp_llm(system_prompt_cot + symptoms)
        resp_tot = lcpp_llm(system_prompt_tot + symptoms)
        # Extract the completion text
        text_few = resp_few['choices'][0]['text'].strip()
        text_cot = resp_cot['choices'][0]['text'].strip()
        text_tot = resp_tot['choices'][0]['text'].strip()
        # Parse the bracketed diagnosis lists
        few = extract_list(text_few)
        cot = extract_list(text_cot)
        tot = extract_list(text_tot)
        most = determine_most_probable(few, cot, tot)
        # Format Markdown output (blank lines keep each item on its own line)
        return f"""
### Differential Diagnosis Results

**Few-Shot Diagnoses:** {', '.join(few) if few else 'No Diagnosis'}

**Chain-of-Thought Diagnoses:** {', '.join(cot) if cot else 'No Diagnosis'}

**Tree-of-Thought Diagnoses:** {', '.join(tot) if tot else 'No Diagnosis'}

**Most Probable Diagnosis:** {most}
"""
    except Exception as e:
        return f"Error: {e}"
# Gradio app definition
with gr.Blocks() as demo:
    gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
    cond = gr.Textbox(label="Patient Condition", placeholder="A 35-year-old male, fever, wheezing, nausea.")
    out = gr.Markdown()
    btn = gr.Button("Diagnose")
    btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
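# To run locally (assuming llama-cpp-python, gradio, and huggingface_hub are
# installed):
#   python app.py
# then open http://localhost:7860 (or the port set via PORT).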