File size: 3,753 Bytes
437a9af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import gradio as gr
import re
from collections import Counter
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Configuration: set environment variables for model repository and file
# (defaults to an 8-bit quantized Meta-Llama-3-8B-Instruct from QuantFactory).
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")

# Download or locate the quantized LLaMA model.
# hf_hub_download caches locally, so repeat launches skip the network fetch.
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)

# Initialize LLaMA via llama_cpp. All knobs are env-overridable so the same
# script runs on CPU-only hosts (set LLAMA_GPU_LAYERS=0) or GPU boxes.
llm = Llama(
    model_path=model_path,
    n_threads=int(os.getenv("LLAMA_THREADS", "4")),        # CPU worker threads
    n_batch=int(os.getenv("LLAMA_BATCH", "256")),          # prompt-eval batch size
    n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")), # layers offloaded to GPU
    n_ctx=int(os.getenv("LLAMA_CTX", "8192"))              # context window in tokens
)

# Prompt templates from notebook. Each template ends with "USER:" so the
# patient symptoms can be appended directly; generation stops at the next
# "USER" token (see lcpp_llm's stop sequence).

# Few-shot prompting: three worked examples teach the bracketed-list
# output format that extract_list() parses downstream.
system_prompt_few_shot = """
SYSTEM:
You are an AI medical assistant specializing in differential diagnosis.
Generate the most likely list of diagnoses based on examples.

USER: A 45-year-old male, fever, cough, fatigue.
SYSTEM: [Flu, COVID-19, Pneumonia]

USER: A 30-year-old female, severe abdominal pain, nausea.
SYSTEM: [Appendicitis, Gallstones, Gastritis]

USER: A 10-year-old female, wheezing.
SYSTEM: [Asthma, Respiratory Infection]

USER:
"""
# Chain-of-thought prompting: asks for step-by-step reasoning before the
# final diagnoses.
system_prompt_cot = """
SYSTEM:
You are a medical expert performing differential diagnosis through step-by-step reasoning.
Provide intermediate reasoning and final diagnoses.

USER:
"""
# Tree-of-thought prompting: asks the model to branch its reasoning before
# converging on final diagnoses.
system_prompt_tot = """
SYSTEM:
You are a medical expert using a tree-of-thought approach for differential diagnosis.
Construct a reasoning tree then provide final diagnoses.

USER:
"""

def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=("USER",)):
    """Run the shared LLaMA model on *prompt* and return the raw completion dict.

    Args:
        prompt: Full prompt text (system template + user symptoms).
        max_tokens: Generation cap forwarded to llama_cpp.
        temperature: 0 gives greedy, deterministic decoding.
        stop: Stop sequences; defaults to "USER" so generation halts before
            the model invents the next conversational turn.

    Returns:
        The llama_cpp completion dict (text under ``['choices'][0]['text']``).
    """
    # Immutable tuple default replaces the original mutable ["USER"] default
    # (shared-list pitfall); converted to a list for the llama_cpp API.
    return llm(prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=list(stop))

def extract_list(text):
    """Parse the first ``[a, b, c]``-style bracketed list found in *text*.

    Args:
        text: Model output expected to contain a bracketed diagnosis list.

    Returns:
        A list of whitespace-stripped item strings; ``[]`` when no bracketed
        list is present or the brackets contain nothing.
    """
    match = re.search(r'\[(.*?)\]', text)
    if not match:
        return []
    # Drop empty fragments so "[]" or stray commas ("[Flu,,]") don't
    # produce "" entries that would pollute the Counter vote downstream.
    return [item.strip() for item in match.group(1).split(",") if item.strip()]

def determine_most_probable(few_shot, cot, tot):
    """Return the diagnosis named most often across the three strategy lists.

    Args:
        few_shot: Diagnoses parsed from the few-shot response.
        cot: Diagnoses parsed from the chain-of-thought response.
        tot: Diagnoses parsed from the tree-of-thought response.

    Returns:
        The most frequent diagnosis string; ties resolve to the earliest
        occurrence (same as the original linear scan). ``"No Clear
        Diagnosis"`` when all three lists are empty.
    """
    counts = Counter(few_shot + cot + tot)
    if not counts:
        return "No Clear Diagnosis"
    # most_common(1) replaces the hand-rolled max scan; heapq.nlargest
    # breaks ties in favor of earlier-inserted keys, preserving behavior.
    return counts.most_common(1)[0][0]

def medical_diagnosis(symptoms: str) -> str:
    """Generate a Markdown differential-diagnosis report for *symptoms*.

    Runs the model once per prompting strategy (few-shot, chain-of-thought,
    tree-of-thought), extracts the bracketed diagnosis list from each reply,
    and tallies them to pick the most probable diagnosis.

    Args:
        symptoms: Free-text patient description appended to each template.

    Returns:
        A Markdown report string; on any failure an ``"Error: ..."`` string
        is returned instead of raising, so the Gradio UI never crashes.
    """
    try:
        # One model call + parse per prompting strategy; a loop replaces
        # the original three copy-pasted call/extract/parse sequences.
        strategies = (
            ("few", system_prompt_few_shot),
            ("cot", system_prompt_cot),
            ("tot", system_prompt_tot),
        )
        diagnoses = {}
        for key, template in strategies:
            response = lcpp_llm(template + symptoms)
            text = response['choices'][0]['text'].strip()
            diagnoses[key] = extract_list(text)

        most = determine_most_probable(
            diagnoses["few"], diagnoses["cot"], diagnoses["tot"]
        )

        def _fmt(items):
            # Render one diagnosis list, with a placeholder when empty.
            return ', '.join(items) if items else 'No Diagnosis'

        # Format Markdown output (byte-identical layout to the original).
        return f"""
### Differential Diagnosis Results

**Few-Shot Diagnoses:** {_fmt(diagnoses['few'])}

**Chain-of-Thought Diagnoses:** {_fmt(diagnoses['cot'])}

**Tree-of-Thought Diagnoses:** {_fmt(diagnoses['tot'])}

**Most Probable Diagnosis:** {most}
"""
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {e}"

# Gradio app definition: single textbox in, Markdown report out.
with gr.Blocks() as demo:
    gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
    cond = gr.Textbox(label="Patient Condition", placeholder="A 35-year-old male, fever, wheezing, nausea.")
    out = gr.Markdown()
    btn = gr.Button("Diagnose")
    # Each click runs all three prompting strategies synchronously;
    # expect multi-second latency per request on CPU-heavy setups.
    btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside containers;
    # port is overridable via the PORT env var (default 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv('PORT', 7860)))