import os
import re
from collections import Counter

import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
# Configuration: set environment variables to override the model repository and file
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")

# Download the quantized LLaMA model (or reuse it from the local HF cache)
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)
# Initialize LLaMA via llama_cpp
llm = Llama(
    model_path=model_path,
    n_threads=int(os.getenv("LLAMA_THREADS", "4")),         # CPU threads for inference
    n_batch=int(os.getenv("LLAMA_BATCH", "256")),           # prompt-processing batch size
    n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")),  # layers to offload to the GPU
    n_ctx=int(os.getenv("LLAMA_CTX", "8192")),              # context window in tokens
)
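# Note: n_gpu_layers only takes effect when llama-cpp-python was built with GPU
# support (e.g. CUDA or Metal); a CPU-only build keeps every layer on the CPU.
# A hedged example for a CPU-only host (assumed values, tune to your hardware):
#   LLAMA_GPU_LAYERS=0 LLAMA_THREADS=8 python app.py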
# Prompt templates from the notebook
system_prompt_few_shot = """
SYSTEM:
You are an AI medical assistant specializing in differential diagnosis.
Generate the most likely list of diagnoses based on the examples.
USER: A 45-year-old male, fever, cough, fatigue.
SYSTEM: [Flu, COVID-19, Pneumonia]
USER: A 30-year-old female, severe abdominal pain, nausea.
SYSTEM: [Appendicitis, Gallstones, Gastritis]
USER: A 10-year-old female, wheezing.
SYSTEM: [Asthma, Respiratory Infection]
USER:
"""
system_prompt_cot = """ | |
SYSTEM: | |
You are a medical expert performing differential diagnosis through step-by-step reasoning. | |
Provide intermediate reasoning and final diagnoses. | |
USER: | |
""" | |
system_prompt_tot = """
SYSTEM:
You are a medical expert using a tree-of-thought approach for differential diagnosis.
Construct a reasoning tree, then provide the final diagnoses.
USER:
"""
def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=["USER"]):
    # Thin wrapper around the llama_cpp completion call; generation halts at
    # the next "USER" turn so each call yields a single assistant reply.
    return llm(prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
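# llama_cpp returns an OpenAI-style completion dict; the generated text lives
# at response['choices'][0]['text'], which medical_diagnosis() reads below.
# Rough shape (illustrative, fields abridged):
#   {'choices': [{'text': ' [Flu, COVID-19, Pneumonia]', 'finish_reason': 'stop'}]}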
def extract_list(text):
    # Pull the first bracketed, comma-separated list out of the model output,
    # skipping empty items so "[]" parses to an empty list rather than [''].
    match = re.search(r'\[(.*?)\]', text)
    if match:
        return [item.strip() for item in match.group(1).split(",") if item.strip()]
    return []
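# Example (runs as written):
#   extract_list("Likely: [Flu, COVID-19, Pneumonia]") -> ['Flu', 'COVID-19', 'Pneumonia']
#   extract_list("no list here") -> []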
def determine_most_probable(few_shot, cot, tot):
    # Majority vote across the three prompting strategies; ties go to the
    # diagnosis counted first.
    counts = Counter(few_shot + cot + tot)
    if not counts:
        return "No Clear Diagnosis"
    max_occ = max(counts.values())
    for diag, cnt in counts.items():
        if cnt == max_occ:
            return diag
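# Example (illustrative): few_shot=['Flu', 'COVID-19'], cot=['Flu'],
# tot=['Asthma'] -> 'Flu' (two votes); empty inputs yield "No Clear Diagnosis".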
def medical_diagnosis(symptoms: str):
    try:
        # Run the same symptoms through all three prompting strategies
        resp_few = lcpp_llm(system_prompt_few_shot + symptoms)
        resp_cot = lcpp_llm(system_prompt_cot + symptoms)
        resp_tot = lcpp_llm(system_prompt_tot + symptoms)
        # Extract the generated text
        text_few = resp_few['choices'][0]['text'].strip()
        text_cot = resp_cot['choices'][0]['text'].strip()
        text_tot = resp_tot['choices'][0]['text'].strip()
        # Parse the bracketed diagnosis lists
        few = extract_list(text_few)
        cot = extract_list(text_cot)
        tot = extract_list(text_tot)
        most = determine_most_probable(few, cot, tot)
        # Format Markdown output (blank lines keep each entry on its own line)
        return f"""### Differential Diagnosis Results

**Few-Shot Diagnoses:** {', '.join(few) if few else 'No Diagnosis'}

**Chain-of-Thought Diagnoses:** {', '.join(cot) if cot else 'No Diagnosis'}

**Tree-of-Thought Diagnoses:** {', '.join(tot) if tot else 'No Diagnosis'}

**Most Probable Diagnosis:** {most}
"""
    except Exception as e:
        return f"Error: {e}"
# Gradio app definition
with gr.Blocks() as demo:
    gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
    cond = gr.Textbox(label="Patient Condition", placeholder="A 35-year-old male, fever, wheezing, nausea.")
    out = gr.Markdown()
    btn = gr.Button("Diagnose")
    btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
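# Quick start (a sketch; assumes this file is saved as app.py and that the
# Q8_0 model, roughly 8-9 GB, fits on disk and in RAM):
#   pip install gradio llama-cpp-python huggingface_hub
#   PORT=7860 LLAMA_GPU_LAYERS=0 python app.py
# Then open http://localhost:7860 and enter a patient description.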