File size: 8,607 Bytes
3b5fe24
9d2bec8
 
5e4e457
9d2bec8
5e4e457
3b5fe24
5e4e457
1058d65
448a2cd
c7fb41b
1058d65
 
 
 
 
3b5fe24
5e4e457
9d2bec8
 
5e4e457
 
3b5fe24
5e4e457
 
9d2bec8
5e4e457
3b5fe24
5e4e457
 
 
 
 
 
 
 
 
 
 
 
 
2d413a5
 
f960da5
 
2d413a5
 
f960da5
 
 
 
 
5e4e457
 
c7fb41b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e4e457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28aec0b
 
 
 
 
c7fb41b
8237afe
5e4e457
 
 
 
 
 
6914fb8
a7c84b5
f4dc3d2
 
 
 
 
 
448a2cd
 
 
 
a7c84b5
0272a23
5e4e457
 
 
 
0272a23
5e4e457
 
3b5fe24
 
 
5e4e457
 
 
 
3b5fe24
 
5e4e457
 
 
 
 
 
 
 
 
 
 
 
 
3b5fe24
 
 
5e4e457
3b5fe24
5e4e457
 
3b5fe24
5e4e457
 
 
 
 
 
3b5fe24
5e4e457
3b5fe24
5e4e457
 
 
 
 
 
 
 
 
 
 
 
3b5fe24
 
 
 
5e4e457
3b5fe24
 
 
5e4e457
 
 
 
3b5fe24
5e4e457
3b5fe24
5e4e457
 
3b5fe24
5e4e457
 
3b5fe24
5e4e457
 
 
 
 
3b5fe24
deb0a6d
 
684322f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import gradio as gr
from utils.model_configuration_utils import select_best_model, ensure_model
from services.llm import build_llm
from utils.voice_input_utils import update_live_transcription, format_response_for_user
from services.embeddings import configure_embeddings
from services.indexing import create_symptom_index
import torchaudio.transforms as T
import re
import logging, sys
import json
from llama_cpp import Llama


# Route log records to stdout at INFO level; force=True first removes any
# handlers already attached to the root logger.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(__name__)


def _run_stage(stage_name, step, *step_args):
    """Call ``step(*step_args)`` between two progress lines and return its result.

    A blank line followed by "<<< before <stage_name>: " is printed before the
    call, and ">>> after <stage_name>" is printed once it returns.
    """
    print(f"\n<<< before {stage_name}: ", flush=True)
    outcome = step(*step_args)
    print(f">>> after {stage_name}", flush=True)
    return outcome


# ========== Model setup ==========
# NOTE(review): ensure_model() is called without the MODEL_NAME / REPO_ID just
# selected — confirm it resolves the same model internally.
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}", flush=True)
print(f"Model path: {model_path}", flush=True)

# ========== LLM initialization ==========
# NOTE(review): `llm` is rebound further down this module to a
# llama_cpp.Llama instance, which is the object on_submit actually calls.
llm = _run_stage("build_llm", build_llm, model_path)

# ========== Embeddings & index setup ==========
_run_stage("configure_embeddings", configure_embeddings)
print("Embeddings configured and ready", flush=True)

symptom_index = _run_stage("create_symptom_index", create_symptom_index)
print("Symptom index built successfully. Ready for queries.", flush=True)

# ========== Prompt template ==========
# System instructions placed at the top of every LLM prompt.
# The pieces are adjacent string literals, which Python concatenates into a
# single str. Do NOT put commas between them: that would make this a tuple,
# and f"{SYSTEM_PROMPT}" would then inject the tuple's repr (quotes, commas
# and literal "\n" escapes) into the prompt.
SYSTEM_PROMPT = (
    "You are a medical assistant helping a user find the most relevant ICD-10 code based on their symptoms.\n"
    "At each turn, determine the top three most relevant ICD-10 codes based on input from the user.\n"
    "Assign a confidence score from 1 to 100 for each code you decided was relevant.\n"
    "Ask the user questions that would raise or lower your confidence score for each code.\n"
    "Replace low-confidence codes with new ones as you learn more.\n"
    "Your goal is to find the most relevant codes with high confidence.\n"
    "When you have high confidence in a code, provide it to the user.\n"
    "Maintain a conversational tone and explain your reasoning step by step.\n"
    "If you need more information, ask the user clarifying questions.\n"
    "End your response with a summary of the top codes and their confidence scores.\n"
    "If you need to ask the user a follow-up question, do so clearly.\n"
)

def truncate_prompt(prompt, max_tokens=2048):
    """Clip *prompt* to its first *max_tokens* whitespace-separated words.

    Words are only a rough proxy for model tokens; ``truncate_prompt_llama``
    counts real tokens with the llama_cpp tokenizer. Unlike a split/join
    round-trip, this keeps the prompt's original spacing and newlines:
    a prompt within the limit is returned unchanged, and a longer one is cut
    immediately after its last kept word.

    Args:
        prompt: Text to clip.
        max_tokens: Maximum number of words to keep; values <= 0 yield "".

    Returns:
        The (possibly shortened) prompt.
    """
    if max_tokens <= 0:
        return ""
    last_kept = None
    for count, word in enumerate(re.finditer(r"\S+", prompt), start=1):
        if count > max_tokens:
            # There is at least one word past the limit: cut right after the
            # last word we are keeping.
            return prompt[:last_kept.end()]
        last_kept = word
    return prompt

# Raw llama_cpp model handle. truncate_prompt_llama uses it for tokenization
# and on_submit uses it for generation; it rebinds the `llm` returned by
# build_llm above.
# NOTE(review): this loads the model weights a second time. If build_llm wraps
# a llama_cpp.Llama, reusing that instance would halve memory use — confirm.
# n_ctx is set explicitly because llama_cpp's default context window (512
# tokens) is smaller than the 2048-token prompt budget used by on_submit, and
# the answer needs room on top of the prompt. Larger values cost more memory
# for the KV cache.
llm = Llama(model_path=model_path, n_ctx=4096)

def truncate_prompt_llama(prompt, max_tokens=2048):
    """Clip *prompt* to at most *max_tokens* tokens of the module-level ``llm``.

    The prompt is tokenized with the llama_cpp model's own tokenizer. If it
    already fits, the original string is returned untouched. Otherwise the
    leading *max_tokens* tokens are kept and decoded back to text; invalid
    UTF-8 produced by cutting inside a multi-byte character is dropped.
    """
    token_ids = llm.tokenize(prompt.encode("utf-8"))
    if len(token_ids) <= max_tokens:
        return prompt
    clipped_bytes = llm.detokenize(token_ids[:max_tokens])
    return clipped_bytes.decode("utf-8", errors="ignore")

# ========== Generator handler ==========
def _completion_text(response):
    """Return the generated text carried by an LLM *response*.

    Accepts a llama_cpp completion dict (``{"choices": [{"text": ...}], ...}``),
    an object exposing ``.text`` or ``.content``, or a plain string. Any other
    value is returned unchanged.
    """
    if isinstance(response, dict) and response.get("choices"):
        first_choice = response["choices"][0]
        if isinstance(first_choice, dict):
            return first_choice.get("text", "")
    if hasattr(response, "text"):
        return response.text
    if hasattr(response, "content"):
        return response.content
    return response


def _parse_llm_output(text):
    """Turn model output into the dict handed to format_response_for_user.

    A string holding a JSON object is decoded and returned. Any other string
    (prose, invalid JSON, or JSON that is not an object) is wrapped as
    ``{"diagnoses": [], "confidences": [], "follow_up": <stripped text>}``.
    Non-string values are passed through unchanged.
    """
    if not isinstance(text, str):
        return text
    try:
        parsed = json.loads(text)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return {"diagnoses": [], "confidences": [], "follow_up": text.strip()}


def on_submit(symptoms_text, history):
    """Stream one symptom -> ICD-10 consultation turn to the Gradio UI.

    Generator bound to the Submit button. Every ``yield`` is a
    ``(chat_history, diagnosis_json, debug_log)`` triple matching the
    ``[chatbot, json_output, debug_box]`` outputs. ``diagnosis_json`` stays
    ``None`` until the final (or error) result is known.

    Args:
        symptoms_text: Text from the symptoms textbox; may be empty or None.
        history: Current chatbot messages (``{"role", "content"}`` dicts).
            May be None, e.g. after the Clear button resets the chat.
    """
    log = []
    history = list(history or [])

    def emit(message):
        """Append *message* to the debug log, echo it to stdout, return the log text."""
        log.append(message)
        print(message, flush=True)
        return "\n".join(log)

    print("on_submit called", flush=True)
    cleaned = (symptoms_text or "").strip()

    # Show the user's own turn in the chat, not only the assistant replies.
    if cleaned:
        history = history + [{"role": "user", "content": cleaned}]
    history = history + [{"role": "assistant", "content": "Processing your request..."}]
    yield history, None, emit("🔍 Received input")

    # Validate
    if not cleaned:
        result = {"error": "No input provided", "diagnoses": [], "confidences": [], "follow_up": []}
        yield history, result, emit("❌ No symptoms provided")
        return

    yield history, None, emit(f"🔄 Cleaned text: {cleaned}")

    # Semantic retrieval. Only the retrieved nodes' metadata is used, so query
    # the retriever directly instead of a query engine, which would also run
    # an LLM synthesis whose answer was discarded. similarity_top_k is passed
    # as a direct retriever argument (LlamaIndex does not read a
    # "retriever_kwargs" option here).
    yield history, None, emit("🔍 Running semantic query")
    nodes = symptom_index.as_retriever(similarity_top_k=5).retrieve(cleaned)
    msg = f"🔍 Retrieved {len(nodes)} context entries"
    history = history + [{"role": "assistant", "content": msg}]
    yield history, None, emit(msg)

    # Build the prompt with minimal context: the top three entries only.
    metadata = [getattr(node, "metadata", None) or {} for node in nodes[:3]]
    context_text = "\n".join(
        f"{md.get('code', '')}: {md.get('description', '')}" for md in metadata
    )
    # SYSTEM_PROMPT is assembled from string fragments. Join it if it is a
    # tuple so the tuple's repr never leaks into the prompt.
    if isinstance(SYSTEM_PROMPT, str):
        system_text = SYSTEM_PROMPT
    else:
        system_text = "".join(SYSTEM_PROMPT)
    prompt = "\n".join([
        system_text,
        f"User symptoms: '{cleaned}'",
        f"Relevant ICD-10 context:\n{context_text}",
    ])

    # llama_cpp rejects requests whose prompt plus completion exceed the
    # context window, and by default it generates only 16 tokens. Split n_ctx
    # explicitly between the prompt and the answer.
    n_ctx = llm.n_ctx()
    max_new_tokens = min(512, n_ctx // 2)
    prompt = truncate_prompt_llama(prompt, max_tokens=min(2048, n_ctx - max_new_tokens))
    yield history, None, emit("✏️ Prompt built")

    # Call the LLM. llama_cpp returns a completion dict, so unwrap the text
    # before trying to read it as JSON.
    response = llm(prompt=prompt, max_tokens=max_new_tokens)
    raw = _parse_llm_output(_completion_text(response))
    assistant_msg = format_response_for_user(raw)
    history = history + [{"role": "assistant", "content": assistant_msg}]
    yield history, raw, emit("✅ Final response appended")

# ========== Gradio UI ==========
with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
    # 🏥 Medical Symptom to ICD-10 Code Assistant
    ## Describe symptoms by typing or speaking.
    Debug log updates live below.
    """
    )
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Type your symptoms",
                placeholder="I'm feeling under the weather...",
                lines=3
            )
            # Streamed microphone audio feeds update_live_transcription, which
            # writes into text_input (see the bindings below).
            microphone = gr.Audio(
                sources=["microphone"],
                streaming=True,
                type="numpy",
                label="Or speak your symptoms..."
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                type="messages"
            )
            json_output = gr.JSON(label="Diagnosis JSON")
            debug_box = gr.Textbox(label="Debug log", lines=10)
        with gr.Column(scale=1):
            # NOTE(review): api_key, model_selector and temperature are not
            # passed to any event handler below, so they currently have no effect.
            with gr.Accordion("API Keys (optional)", open=False):
                api_key = gr.Textbox(label="OpenAI Key", type="password")
                model_selector = gr.Dropdown(
                    choices=["OpenAI","Modal","Anthropic","MistralAI","Nebius","Hyperbolic","SambaNova"],
                    value="OpenAI",
                    label="Model Provider"
                )
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")

    # Bindings
    # on_submit is a generator; each yield updates the chat, JSON and debug log.
    submit_btn.click(
        fn=on_submit,
        inputs=[text_input, chatbot],
        outputs=[chatbot, json_output, debug_box],
        queue=True
    )
    # Reset the chat, the JSON panel and the debug log.
    clear_btn.click(
        lambda: (None, {}, ""),
        None,
        [chatbot, json_output, debug_box],
        queue=False
    )
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=[text_input],
        queue=True
    )

    # --- About the Creator ---
    gr.Markdown("""
    ---
    ### 👋 About the Creator

    Hi! I'm Graham Paasch, an experienced technology professional!

    🎥 **Check out my YouTube channel** for more tech content:  
    [Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

    💼 **Looking for a skilled developer?**
    I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

    ⭐ If you found this tool helpful, please consider:
    - Subscribing to my YouTube channel
    - Connecting on LinkedIn
    - Sharing this tool with others in healthcare tech
    """
    )

if __name__ == "__main__":
    # Listen on all interfaces at port 7860, request a public share link,
    # expose the API page and enable the MCP server.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_api": True,
        "mcp_server": True,
    }
    demo.launch(**launch_options)