# MedCodeMCP / app.py
# Author: gpaasch
# Commit 6914fb8: "managing token limit"
import gradio as gr
from utils.model_configuration_utils import select_best_model, ensure_model
from services.llm import build_llm
from utils.voice_input_utils import update_live_transcription, format_response_for_user
from services.embeddings import configure_embeddings
from services.indexing import create_symptom_index
import torchaudio.transforms as T
import re
import logging, sys
import json
from llama_cpp import Llama
# Send INFO-and-above log records to stdout; force=True first removes and
# closes any handlers already attached to the root logger.
logging.basicConfig(level=logging.INFO, force=True, stream=sys.stdout)
logger = logging.getLogger(name=__name__)
# ========== Model setup ==========
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}", flush=True)
print(f"Model path: {model_path}", flush=True)


def _run_stage(label, action, *args):
    """Call ``action(*args)`` between "<<< before" / ">>> after" progress markers.

    The "after" marker is printed only if *action* returns normally. The
    action's return value is passed through so a stage can be assigned
    directly.
    """
    print(f"\n<<< before {label}: ", flush=True)
    outcome = action(*args)
    print(f">>> after {label}", flush=True)
    return outcome


# ========== LLM initialization ==========
llm = _run_stage("build_llm", build_llm, model_path)

# ========== Embeddings & index setup ==========
_run_stage("configure_embeddings", configure_embeddings)
print("Embeddings configured and ready", flush=True)
symptom_index = _run_stage("create_symptom_index", create_symptom_index)
print("Symptom index built successfully. Ready for queries.", flush=True)
# ========== Prompt template ==========
# System instructions placed at the top of every LLM prompt.
# The pieces below are adjacent string literals, which Python concatenates
# into ONE str. Do not put commas between them: that would make
# SYSTEM_PROMPT a tuple, and formatting it (e.g. f"{SYSTEM_PROMPT}") would
# inject the tuple's repr (parentheses, quotes, escaped "\n") into the prompt.
SYSTEM_PROMPT = (
    "You are a medical assistant helping a user find the most relevant ICD-10 code based on their symptoms.\n"
    "At each turn, determine the top three most relevant ICD-10 codes based on input from the user.\n"
    "Assign a confidence score from 1 to 100 for each code you decided was relevant.\n"
    "Ask the user questions that would raise or lower your confidence score for each code.\n"
    "Replace low-confidence codes with new ones as you learn more.\n"
    "Your goal is to find the most relevant codes with high confidence.\n"
    "When you have high confidence in a code, provide it to the user.\n"
    "Maintain a conversational tone and explain your reasoning step by step.\n"
    "If you need more information, ask the user clarifying questions.\n"
    "End your response with a summary of the top codes and their confidence scores.\n"
    "If you need to ask the user a follow-up question, do so clearly.\n"
)
def truncate_prompt(prompt, max_tokens=2048):
    """Clip *prompt* to at most *max_tokens* whitespace-separated words.

    This is a cheap approximation: words are not model tokens, so the result
    may still exceed a model's real token budget. truncate_prompt_llama does
    tokenizer-accurate truncation.

    Prompts already within the limit are returned unchanged. When truncation
    is needed, the retained prefix keeps its original whitespace (newlines,
    indentation) instead of being re-joined with single spaces.

    Args:
        prompt: Text to clip.
        max_tokens: Maximum number of words to keep; values <= 0 yield "".

    Returns:
        The original prompt, or its prefix ending at the last kept word.
    """
    if max_tokens <= 0:
        return ""
    cut = 0
    for count, match in enumerate(re.finditer(r"\S+", prompt)):
        if count == max_tokens:
            # A (max_tokens + 1)-th word exists: keep everything up to the end
            # of the max_tokens-th word.
            return prompt[:cut]
        cut = match.end()
    return prompt
# Load the llama.cpp model used for tokenization (truncate_prompt_llama) and
# completion (on_submit).
# NOTE(review): this rebinds ``llm``; build_llm(model_path) above presumably
# already loaded these weights, so this may hold a second copy in memory.
# Consider reusing the underlying model if build_llm exposes it.
# llama_cpp.Llama defaults to n_ctx=512, far below the 2048-token prompt
# budget applied by truncate_prompt_llama, so longer prompts would be
# rejected. Reserve room for the prompt plus the generated answer.
llm = Llama(model_path=model_path, n_ctx=4096)
def truncate_prompt_llama(prompt, max_tokens=2048):
    """Clip *prompt* so that it is at most *max_tokens* model tokens long.

    Tokenization uses the module-level ``llm`` (llama_cpp) instance, so the
    limit is measured in model tokens rather than words.

    NOTE(review): llama_cpp's ``tokenize`` is believed to add a BOS token by
    default, which would be counted here. Confirm that is intended.

    Args:
        prompt: Prompt text.
        max_tokens: Maximum number of model tokens to keep.

    Returns:
        *prompt* itself when it already fits; otherwise the text decoded from
        its first *max_tokens* tokens.
    """
    token_ids = llm.tokenize(prompt.encode("utf-8"))
    if len(token_ids) <= max_tokens:
        return prompt
    # A token boundary can split a multi-byte UTF-8 sequence;
    # errors="ignore" drops any such partial character.
    clipped_bytes = llm.detokenize(token_ids[:max_tokens])
    return clipped_bytes.decode("utf-8", errors="ignore")
# ========== Generator handler ==========
# Upper bound on generated tokens per answer. The llama_cpp completion call
# defaults to max_tokens=16, which cuts answers off mid-sentence. Recent
# llama-cpp-python versions shrink this value to fit whatever room the
# context window has left after the prompt.
_COMPLETION_MAX_TOKENS = 512


def _error_result(message):
    """Return a diagnosis payload describing a failure, in the JSON panel's usual shape."""
    return {"error": message, "diagnoses": [], "confidences": [], "follow_up": []}


def _build_prompt(cleaned, hits, max_context=3):
    """Assemble the LLM prompt and clip it to the token budget.

    The prompt holds the system instructions, the user's symptoms, and up to
    *max_context* retrieved ICD-10 entries.

    Args:
        cleaned: Stripped symptom text from the user.
        hits: Query-engine result. Its optional ``source_nodes`` entries may
            carry a ``metadata`` dict with ``code``/``description`` keys.
        max_context: Maximum number of retrieved entries to include.

    Returns:
        The prompt string after truncate_prompt_llama.
    """
    nodes = getattr(hits, "source_nodes", None) or []
    context_lines = []
    for node in nodes[:max_context]:
        md = getattr(node, "metadata", None) or {}
        context_lines.append(f"{md.get('code', '')}: {md.get('description', '')}")
    # "".join accepts either a plain string or a tuple of string fragments,
    # so the prompt never contains a Python tuple repr.
    system_text = "".join(SYSTEM_PROMPT)
    prompt = "\n".join([
        system_text,
        f"User symptoms: '{cleaned}'",
        "Relevant ICD-10 context:\n" + "\n".join(context_lines),
    ])
    return truncate_prompt_llama(prompt, max_tokens=2048)


def _parse_llm_response(response):
    """Normalise a raw LLM completion into the payload for format_response_for_user.

    Accepted inputs:
      * llama_cpp completion dicts (``{"choices": [{"text": ...}], ...}``);
      * objects exposing ``.text`` or ``.content``;
      * plain strings.

    Text that parses as a JSON object is returned as that dict. Any other
    text is wrapped as ``{"diagnoses": [], "confidences": [], "follow_up": text}``.
    Non-string, non-completion values are returned unchanged.
    """
    raw = response
    if isinstance(raw, dict) and "choices" in raw:
        choices = raw.get("choices") or []
        first = choices[0] if choices else {}
        raw = first.get("text", "") if isinstance(first, dict) else ""
    elif hasattr(raw, "text"):
        raw = raw.text
    elif hasattr(raw, "content"):
        raw = raw.content

    if not isinstance(raw, str):
        return raw
    try:
        parsed = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None
    if isinstance(parsed, dict):
        return parsed
    return {"diagnoses": [], "confidences": [], "follow_up": raw}


def on_submit(symptoms_text, history):
    """Gradio generator handler for the Submit button.

    Streams progress while it retrieves ICD-10 context for the user's
    symptoms, asks the LLM, and appends the formatted answer to the chat.

    Args:
        symptoms_text: Raw text from the symptoms textbox. None is treated
            as empty.
        history: Chat history in Gradio "messages" format (a list of
            ``{"role", "content"}`` dicts). None (e.g. after Clear) is
            treated as an empty history.

    Yields:
        ``(history, result, debug_log)`` tuples.
        * ``history``: the updated message list.
        * ``result``: None while in progress, then the diagnosis dict (or an
          error dict).
        * ``debug_log``: the newline-joined progress log.
    """
    log = []

    def step(message):
        """Record *message* in the debug log, echo it to stdout, and return the joined log."""
        log.append(message)
        print(message, flush=True)
        return "\n".join(log)

    print("on_submit called", flush=True)
    history = list(history or [])

    status = step("🔍 Received input")
    history = history + [{"role": "assistant", "content": "Processing your request..."}]
    yield history, None, status

    # Validate
    cleaned = (symptoms_text or "").strip()
    if not cleaned:
        status = step("❌ No symptoms provided")
        yield history, _error_result("No input provided"), status
        return

    yield history, None, step(f"🔄 Cleaned text: {cleaned}")

    # Semantic query
    yield history, None, step("🔍 Running semantic query")
    qe = symptom_index.as_query_engine(retriever_kwargs={"similarity_top_k": 5})
    hits = qe.query(cleaned)
    msg = "🔍 Retrieved context entries"
    history = history + [{"role": "assistant", "content": msg}]
    yield history, None, step(msg)

    prompt = _build_prompt(cleaned, hits)
    yield history, None, step("✏️ Prompt built")

    # Call the LLM. This is the UI boundary: report failures (e.g. a prompt
    # that exceeds the context window) to the user instead of aborting the
    # generator.
    try:
        response = llm(prompt=prompt, max_tokens=_COMPLETION_MAX_TOKENS)
    except Exception as exc:
        logger.exception("LLM call failed")
        status = step(f"❌ LLM call failed: {exc}")
        history = history + [{
            "role": "assistant",
            "content": "Sorry, the model could not process this request.",
        }]
        yield history, _error_result(str(exc)), status
        return

    result = _parse_llm_response(response)
    assistant_msg = format_response_for_user(result)
    history = history + [{"role": "assistant", "content": assistant_msg}]
    yield history, result, step("✅ Final response appended")
# ========== Gradio UI ==========
with gr.Blocks(theme="default") as demo:
    # Markdown content is kept at column 0 so indentation can never turn it
    # into an indented code block.
    gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## Describe symptoms by typing or speaking.
Debug log updates live below.
""")

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Type your symptoms",
                placeholder="I'm feeling under the weather...",
                lines=3,
            )
            microphone = gr.Audio(
                sources=["microphone"],
                streaming=True,
                type="numpy",
                label="Or speak your symptoms...",
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                type="messages",
            )
            json_output = gr.JSON(label="Diagnosis JSON")
            debug_box = gr.Textbox(label="Debug log", lines=10)

        with gr.Column(scale=1):
            # NOTE(review): api_key, model_selector and temperature are not
            # passed to any event handler in this file, so changing them
            # currently has no effect.
            with gr.Accordion("API Keys (optional)", open=False):
                api_key = gr.Textbox(label="OpenAI Key", type="password")
                model_selector = gr.Dropdown(
                    choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                    value="OpenAI",
                    label="Model Provider",
                )
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")

    # Bindings
    # on_submit is a generator; each yield updates the chat, JSON and debug log.
    submit_btn.click(
        fn=on_submit,
        inputs=[text_input, chatbot],
        outputs=[chatbot, json_output, debug_box],
        queue=True,
    )
    # Reset the chat to an empty list (not None) so the next submit receives
    # a valid history.
    clear_btn.click(
        lambda: ([], {}, ""),
        None,
        [chatbot, json_output, debug_box],
        queue=False,
    )
    # Feed streamed microphone chunks to update_live_transcription, whose
    # return value replaces the symptoms textbox (presumably the running
    # transcript).
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=[text_input],
        queue=True,
    )

    # --- About the Creator ---
    gr.Markdown("""
---
### 👋 About the Creator

Hi! I'm Graham Paasch, an experienced technology professional!

🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
if __name__ == "__main__":
    # Bind on all interfaces at port 7860, request a public share link, and
    # expose both the REST API page and the MCP server endpoint.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_api": True,
        "mcp_server": True,
    }
    demo.launch(**launch_options)