import gradio as gr
from utils.model_configuration_utils import select_best_model, ensure_model
from services.llm import build_llm
from utils.voice_input_utils import update_live_transcription, format_response_for_user
from services.embeddings import configure_embeddings
from services.indexing import create_symptom_index
import torchaudio.transforms as T
import re
import logging, sys
import json
from llama_cpp import Llama

logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)
logger = logging.getLogger(__name__)
# ========== Model setup ==========
MODEL_NAME, REPO_ID = select_best_model()
model_path = ensure_model()
print(f"Using model: {MODEL_NAME} from {REPO_ID}", flush=True)
print(f"Model path: {model_path}", flush=True)
# ========== LLM initialization ==========
print("\n<<< before build_llm", flush=True)
llm = build_llm(model_path)
print(">>> after build_llm", flush=True)

# ========== Embeddings & index setup ==========
print("\n<<< before configure_embeddings", flush=True)
configure_embeddings()
print(">>> after configure_embeddings", flush=True)
print("Embeddings configured and ready", flush=True)

print("\n<<< before create_symptom_index", flush=True)
symptom_index = create_symptom_index()
print(">>> after create_symptom_index", flush=True)
print("Symptom index built successfully. Ready for queries.", flush=True)
# ========== Prompt template ========== | |
SYSTEM_PROMPT = ( | |
"You are a medical assistant helping a user find the most relevant ICD-10 code based on their symptoms.\n", | |
"At each turn, determine the top three most relevant ICD-10 codes based on input from the user.\n", | |
"Assign a confidence score from 1 to 100 for each code you decided was relevant.\n", | |
"Asking a question to the user to raise or lower your confidence score for each code.\n", | |
"Replace low-confidence codes with new ones as you learn more.\n", | |
"Your goal is to find the most relevant codes with high confidence.\n", | |
"When you have high confidence in a code, provide it to the user.\n", | |
"Maintain a conversational tone and explain your reasoning step by step.\n", | |
"If you need more information, ask the user clarifying questions.\n", | |
"End your response with a summary of the top codes and their confidence scores.\n", | |
"If you need to ask the user a follow-up question, do so clearly.\n", | |
) | |
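
# Illustrative only: with hypothetical user input and retrieved entries, the prompt assembled
# in on_submit below would look roughly like this (the codes are real ICD-10-CM entries used
# purely as an example):
#
#   <SYSTEM_PROMPT text>
#   User symptoms: 'persistent dry cough and mild fever for three days'
#   Relevant ICD-10 context:
#   R05.9: Cough, unspecified
#   R50.9: Fever, unspecified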
def truncate_prompt(prompt, max_tokens=2048):
    # Rough placeholder: splits on whitespace rather than real tokens.
    # truncate_prompt_llama below uses the model's actual tokenizer.
    tokens = prompt.split()
    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]
    return " ".join(tokens)

# Initialize the llama_cpp model directly (adjust path and params as needed).
# Note: this rebinds `llm`, replacing the instance returned by build_llm above;
# the llama_cpp handle is what exposes tokenize/detokenize used below.
llm = Llama(model_path=model_path)
def truncate_prompt_llama(prompt, max_tokens=2048):
    # Tokenize the prompt using llama_cpp's tokenizer
    tokens = llm.tokenize(prompt.encode("utf-8"))
    if len(tokens) > max_tokens:
        # Truncate the token list and decode back to a string
        tokens = tokens[:max_tokens]
        prompt = llm.detokenize(tokens).decode("utf-8", errors="ignore")
    return prompt
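
# Usage sketch (hypothetical values): keep a long prompt inside the context window before
# inference. The max_tokens passed here is the prompt budget, separate from the completion's
# own max_tokens.
#
#   safe_prompt = truncate_prompt_llama(long_prompt, max_tokens=2048)
#   completion = llm(prompt=safe_prompt, max_tokens=256)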
# ========== Generator handler ==========
def on_submit(symptoms_text, history):
    log = []
    print("on_submit called", flush=True)

    # Acknowledge receipt immediately
    msg = "📥 Received input"
    log.append(msg)
    print(msg, flush=True)
    history = history + [{"role": "assistant", "content": "Processing your request..."}]
    yield history, None, "\n".join(log)

    # Validate
    if not symptoms_text.strip():
        msg = "❌ No symptoms provided"
        log.append(msg)
        print(msg, flush=True)
        result = {"error": "No input provided", "diagnoses": [], "confidences": [], "follow_up": []}
        yield history, result, "\n".join(log)
        return

    # Clean input
    cleaned = symptoms_text.strip()
    msg = f"🧹 Cleaned text: {cleaned}"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)

    # Semantic query
    msg = "🔍 Running semantic query"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)
    qe = symptom_index.as_query_engine(retriever_kwargs={"similarity_top_k": 5})
    hits = qe.query(cleaned)
    msg = "📚 Retrieved context entries"
    log.append(msg)
    print(msg, flush=True)
    history = history + [{"role": "assistant", "content": msg}]
    yield history, None, "\n".join(log)
    # Build prompt with minimal context
    context_list = []
    for node in getattr(hits, 'source_nodes', [])[:3]:
        md = getattr(node, 'metadata', {}) or {}
        context_list.append(f"{md.get('code','')}: {md.get('description','')}")
    context_text = "\n".join(context_list)

    prompt = "\n".join([
        f"{SYSTEM_PROMPT}",
        f"User symptoms: '{cleaned}'",
        f"Relevant ICD-10 context:\n{context_text}",
    ])
    prompt = truncate_prompt_llama(prompt, max_tokens=2048)
    msg = "✏️ Prompt built"
    log.append(msg)
    print(msg, flush=True)
    yield history, None, "\n".join(log)

    # Call LLM
    response = llm(prompt=prompt)
    raw = response
    # Extract text from a CompletionResponse-style object if needed
    if hasattr(raw, "text"):
        raw = raw.text
    elif hasattr(raw, "content"):
        raw = raw.content
    # llama_cpp returns a completion dict; pull out the generated text
    elif isinstance(raw, dict) and "choices" in raw:
        raw = raw["choices"][0].get("text", "")
    # Now ensure it's a dict
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except Exception:
            raw = {"diagnoses": [], "confidences": [], "follow_up": raw}

    assistant_msg = format_response_for_user(raw)
    history = history + [{"role": "assistant", "content": assistant_msg}]
    msg = "✅ Final response appended"
    log.append(msg)
    print(msg, flush=True)
    yield history, raw, "\n".join(log)
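
# For reference, a successful turn is expected to yield a JSON payload shaped roughly like the
# following (illustrative values; the actual keys depend on what the model returns and on
# format_response_for_user):
#
#   {"diagnoses": ["R51.9", "G43.909"],
#    "confidences": [82, 64],
#    "follow_up": ["How long have the headaches lasted?"]}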
# ========== Gradio UI ==========
with gr.Blocks(theme="default") as demo:
    gr.Markdown("""
# 🏥 Medical Symptom to ICD-10 Code Assistant
## Describe symptoms by typing or speaking.
Debug log updates live below.
""")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Type your symptoms",
                placeholder="I'm feeling under the weather...",
                lines=3
            )
            microphone = gr.Audio(
                sources=["microphone"],
                streaming=True,
                type="numpy",
                label="Or speak your symptoms..."
            )
            submit_btn = gr.Button("Submit", variant="primary")
            clear_btn = gr.Button("Clear Chat", variant="secondary")
            chatbot = gr.Chatbot(
                label="Medical Consultation",
                height=500,
                type="messages"
            )
            json_output = gr.JSON(label="Diagnosis JSON")
            debug_box = gr.Textbox(label="Debug log", lines=10)
        with gr.Column(scale=1):
            with gr.Accordion("API Keys (optional)", open=False):
                api_key = gr.Textbox(label="OpenAI Key", type="password")
            model_selector = gr.Dropdown(
                choices=["OpenAI", "Modal", "Anthropic", "MistralAI", "Nebius", "Hyperbolic", "SambaNova"],
                value="OpenAI",
                label="Model Provider"
            )
            temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")

    # Bindings
    submit_btn.click(
        fn=on_submit,
        inputs=[text_input, chatbot],
        outputs=[chatbot, json_output, debug_box],
        queue=True
    )
    clear_btn.click(
        lambda: (None, {}, ""),
        None,
        [chatbot, json_output, debug_box],
        queue=False
    )
    microphone.stream(
        fn=update_live_transcription,
        inputs=[microphone],
        outputs=[text_input],
        queue=True
    )
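
    # Note: because on_submit is a generator and submit_btn.click uses queue=True, each `yield`
    # in the handler streams an incremental update to the chatbot, JSON panel, and debug log
    # while processing continues.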
    # --- About the Creator ---
    gr.Markdown("""
---
### 👋 About the Creator

Hi! I'm Graham Paasch, an experienced technology professional!

🎥 **Check out my YouTube channel** for more tech content:
[Subscribe to my channel](https://www.youtube.com/channel/UCg3oUjrSYcqsL9rGk1g_lPQ)

💼 **Looking for a skilled developer?**
I'm currently seeking new opportunities! View my experience and connect on [LinkedIn](https://www.linkedin.com/in/grahampaasch/)

⭐ If you found this tool helpful, please consider:
- Subscribing to my YouTube channel
- Connecting on LinkedIn
- Sharing this tool with others in healthcare tech
""")
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_api=True, mcp_server=True)
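
# To run locally (assumes the repo's utils/ and services/ packages are importable and the GGUF
# model can be fetched by ensure_model):
#   pip install "gradio[mcp]" llama-cpp-python torchaudio
#   python app.py        # or whatever this file is named in the repo
# The UI is served on port 7860; share=True also creates a temporary public Gradio link, and
# mcp_server=True exposes the handlers as MCP tools on recent Gradio versions.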