import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Configuration & Model Loading ---
# Use the pipeline API, which is more robust than driving the processor and
# model.generate() by hand.
print("Loading MedGemma model via pipeline...")
try:
    pipe = pipeline(
        "image-text-to-text",  # the chat-style multimodal task MedGemma expects
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},  # pass the dtype here
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")

# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """
    Manages one conversation turn: rebuilds the chat-message list the pipeline
    expects from the display history, runs generation, and returns the updated
    UI state.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    # Prefer a freshly uploaded image; otherwise keep the one persisted in state.
    current_image = new_image_upload if new_image_upload is not None else image_state

    # Build the 'messages' list in the chat format the pipeline expects.
    messages = []
    # Optional: a system prompt can be prepended here, following the same format.
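    # A minimal sketch of such a system prompt (the wording is an illustrative
    # assumption, not part of the original app):
    # messages.append({
    #     "role": "system",
    #     "content": [{"type": "text", "text": (
    #         "You are a cautious medical assistant. Ask clarifying questions "
    #         "and recommend professional evaluation for any diagnosis."
    #     )}],
    # })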
    # Replay the conversation history. The image (if any) belongs to the first
    # user turn, so re-attach it there when rebuilding; later turns are text-only.
    for i, (user_msg, assistant_msg) in enumerate(history_for_display):
        user_content = [{"type": "text", "text": user_msg}]
        if i == 0 and current_image is not None:
            user_content.append({"type": "image", "image": current_image})
        messages.append({"role": "user", "content": user_content})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

    # Add the current user turn. If this is the very first message of the
    # conversation, attach the image (when present) to it directly.
    current_user_content = [{"type": "text", "text": user_input}]
    if current_image is not None and not history_for_display:
        current_user_content.append({"type": "image", "image": current_image})
    messages.append({"role": "user", "content": current_user_content})
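    # For reference, after one completed exchange the list built above looks
    # roughly like this (values are illustrative):
    # [
    #     {"role": "user", "content": [{"type": "text", "text": "I have a rash on my arm."},
    #                                  {"type": "image", "image": <PIL.Image>}]},
    #     {"role": "assistant", "content": [{"type": "text", "text": "How long has it been there?"}]},
    #     {"role": "user", "content": [{"type": "text", "text": "About two days."}]},
    # ]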
    try:
        # The image-text-to-text pipeline takes the whole chat as `text`;
        # images travel inside the message contents, so no separate image
        # argument is needed.
        output = pipe(text=messages, max_new_tokens=512)

        # The pipeline returns the chat with the assistant's reply appended.
        result = output[0]["generated_text"]
        if isinstance(result, list):
            # Chat-style output: the last entry is the assistant's message.
            clean_response = result[-1]["content"]
        else:
            # Plain text-only output.
            clean_response = result
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        # Surface the real error in the UI for easier debugging.
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
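
# A quick way to smoke-test the function outside the UI (illustrative; assumes
# the model loaded and, for the image path, a local file such as "rash.jpg"):
#   history, img, _, _ = symptom_checker_chat(
#       "I have a red, itchy rash on my forearm.", [], Image.open("rash.jpg"), None
#   )
#   print(history[-1][1])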

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
        """
    )

    # State that persists the uploaded image across the whole conversation.
    image_state = gr.State(None)
    # The Chatbot component itself holds the display history as a list of
    # (user, assistant) tuples, so no separate history state is needed.
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Reset the chat display, the persisted image, the upload box, and the textbox.
    def clear_all():
        return [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(fn=clear_all, outputs=[chatbot, image_state, image_box, text_box], queue=False)

    def on_submit(user_input, display_history, new_image, persisted_image):
        # Ignore empty submissions (no text and no image).
        if not user_input.strip() and not new_image:
            return display_history, persisted_image, None, ""
        # The chatbot's displayed history IS the conversation state.
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)

    # Wire up both the Send button and the textbox's Enter key.
    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )

if __name__ == "__main__":
    demo.launch(debug=True)