import os
import gradio as gr
import re
from collections import Counter
from llama_cpp import Llama
from huggingface_hub import hf_hub_download

# Configuration: model repository and filename, overridable via the HF_REPO_ID / HF_MODEL_BASENAME environment variables
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")

# Download or locate the quantized LLaMA model
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)

# Initialize LLaMA via llama_cpp; runtime settings are overridable via environment variables
llm = Llama(
    model_path=model_path,
    n_threads=int(os.getenv("LLAMA_THREADS", "4")),         # CPU threads for generation
    n_batch=int(os.getenv("LLAMA_BATCH", "256")),           # prompt-processing batch size
    n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")),  # exceeds the model's 32 layers, so all layers are offloaded when GPU support is compiled in
    n_ctx=int(os.getenv("LLAMA_CTX", "8192"))               # context window size
)
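
# Illustrative shell usage for the overrides above (the Q4_K_M filename is an
# assumption about the repo's available quantizations, not taken from this file):
#   HF_MODEL_BASENAME=Meta-Llama-3-8B-Instruct.Q4_K_M.gguf LLAMA_GPU_LAYERS=0 python app.py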

# Prompt templates for the three prompting strategies (carried over from the original notebook)
system_prompt_few_shot = """
SYSTEM:
You are an AI medical assistant specializing in differential diagnosis.
Generate the most likely list of diagnoses based on examples.

USER: A 45-year-old male, fever, cough, fatigue.
SYSTEM: [Flu, COVID-19, Pneumonia]

USER: A 30-year-old female, severe abdominal pain, nausea.
SYSTEM: [Appendicitis, Gallstones, Gastritis]

USER: A 10-year-old female, wheezing.
SYSTEM: [Asthma, Respiratory Infection]

USER:
"""
system_prompt_cot = """
SYSTEM:
You are a medical expert performing differential diagnosis through step-by-step reasoning.
Provide intermediate reasoning and final diagnoses.

USER:
"""
system_prompt_tot = """
SYSTEM:
You are a medical expert using a tree-of-thought approach for differential diagnosis.
Construct a reasoning tree then provide final diagnoses.

USER:
"""

def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=("USER",)):
    """Run a single completion, halting when the model begins a new USER turn."""
    # A tuple default avoids the mutable-default-argument pitfall; llama_cpp expects a list
    return llm(prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=list(stop))

def extract_list(text):
    """Parse the last bracketed list in a response into a list of diagnoses."""
    # The last match is used because CoT/ToT responses may contain intermediate
    # bracketed steps before the final diagnosis list; empty items are dropped
    matches = re.findall(r'\[(.*?)\]', text)
    if matches:
        return [item.strip() for item in matches[-1].split(",") if item.strip()]
    return []
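
# e.g. extract_list("Step 1: ... Final answer: [Flu, COVID-19, Pneumonia]")
#   -> ['Flu', 'COVID-19', 'Pneumonia']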

def determine_most_probable(few_shot, cot, tot):
    """Majority vote across the three strategies' diagnosis lists."""
    counts = Counter(few_shot + cot + tot)
    if not counts:
        return "No Clear Diagnosis"
    # most_common is stable, so ties resolve to the first-seen diagnosis,
    # matching the original max-count loop
    return counts.most_common(1)[0][0]
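
# e.g. determine_most_probable(["Flu", "Pneumonia"], ["Flu"], ["Asthma"]) -> "Flu"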

def medical_diagnosis(symptoms: str):
    """Query the model with all three prompting strategies and vote on a final diagnosis."""
    try:
        # Generate one completion per prompting strategy
        resp_few = lcpp_llm(system_prompt_few_shot + symptoms)
        resp_cot = lcpp_llm(system_prompt_cot + symptoms)
        resp_tot = lcpp_llm(system_prompt_tot + symptoms)

        # Extract text
        text_few = resp_few['choices'][0]['text'].strip()
        text_cot = resp_cot['choices'][0]['text'].strip()
        text_tot = resp_tot['choices'][0]['text'].strip()

        # Parse lists
        few = extract_list(text_few)
        cot = extract_list(text_cot)
        tot = extract_list(text_tot)
        most = determine_most_probable(few, cot, tot)

        # Format Markdown output
        return f"""
### Differential Diagnosis Results

**Few-Shot Diagnoses:** {', '.join(few) if few else 'No Diagnosis'}

**Chain-of-Thought Diagnoses:** {', '.join(cot) if cot else 'No Diagnosis'}

**Tree-of-Thought Diagnoses:** {', '.join(tot) if tot else 'No Diagnosis'}

**Most Probable Diagnosis:** {most}
"""
    except Exception as e:
        return f"Error: {e}"

# Gradio app definition
with gr.Blocks() as demo:
    gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
    cond = gr.Textbox(label="Patient Condition", placeholder="A 35-year-old male, fever, wheezing, nausea.")
    out = gr.Markdown()
    btn = gr.Button("Diagnose")
    btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))