Spaces:
Running
Running
File size: 6,705 Bytes
83ff66a 77793f4 b67fca4 77793f4 83ff66a b0565c1 83ff66a 77793f4 eef3b89 83ff66a f48ab42 77793f4 eef3b89 77793f4 eef3b89 6ef5bdf 77793f4 83ff66a 77793f4 83ff66a b67fca4 83ff66a 77793f4 b0565c1 f48ab42 cc7489e b67fca4 f48ab42 b67fca4 cc7489e b0565c1 6ef5bdf cc7489e 5296e2a f48ab42 77793f4 f48ab42 77793f4 f48ab42 5296e2a f48ab42 77793f4 f48ab42 cc7489e f48ab42 eef3b89 77793f4 cc7489e f48ab42 5296e2a 77793f4 2166c8b b47c12e b0565c1 46668b2 5296e2a 46668b2 b0565c1 6ef5bdf 5296e2a b0565c1 eef3b89 b0565c1 cc7489e b0565c1 eef3b89 6ef5bdf cc7489e 5296e2a 6ef5bdf b0565c1 6ef5bdf cc7489e b0565c1 6ef5bdf cc7489e b0565c1 83ff66a 6ef5bdf eef3b89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
# --- Configuration & Model Loading ---
print("Loading MedGemma model via pipeline...")
model_loaded = False
pipe = None
try:
    # MedGemma is a vision-language *chat* model. The "image-text-to-text"
    # pipeline is the task that accepts chat-formatted messages (optionally
    # with inline images); "image-to-text" only accepts an image plus an
    # optional plain-string prompt and rejects a list of messages.
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Keep the UI alive even if loading fails; the chat handler reports it.
    print(f"Error loading model: {e}")


def _text_part(text):
    """Wrap a plain string as a chat-template text content part."""
    return {"type": "text", "text": text}


def _build_messages(history, user_input, image):
    """Convert display history plus the new turn into pipeline chat messages.

    Args:
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        user_input: The user's new text message.
        image: Image to attach to the newest user turn, or ``None``.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts.
    """
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": [_text_part(user_msg)]})
        # Always emit the assistant turn (empty if missing): Gemma-style chat
        # templates require strictly alternating user/assistant roles.
        messages.append(
            {"role": "assistant", "content": [_text_part(assistant_msg or "")]}
        )

    current_content = [_text_part(user_input)]
    if image is not None:
        # Images travel inside the message content for the chat pipeline.
        current_content.append({"type": "image", "image": image})
    messages.append({"role": "user", "content": current_content})
    return messages


# --- Core Chatbot Function ---
@spaces.GPU(duration=120)
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """Generate the model's reply for one chat turn.

    Args:
        user_input: The user's new text message.
        history_for_display: Earlier turns as ``(user_text, assistant_text)``
            pairs; the new turn is appended to this list in place.
        new_image_upload: Image uploaded with this turn, or ``None``.
        image_state: Image persisted from an earlier turn, or ``None``.

    Returns:
        ``(history, image_to_persist, None, "")``. The trailing ``None`` and
        ``""`` clear the image upload box and the text box in the UI.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    # A fresh upload replaces the persisted image; otherwise keep the old one.
    current_image = new_image_upload if new_image_upload is not None else image_state
    messages = _build_messages(history_for_display, user_input, current_image)

    try:
        output = pipe(
            text=messages,
            max_new_tokens=512,
            generate_kwargs={"do_sample": True, "temperature": 0.7},
        )
        # In chat mode the pipeline returns the full conversation; the last
        # message is the newly generated assistant reply.
        clean_response = output[0]["generated_text"][-1]["content"]
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image.
        """
    )
    image_state = gr.State(None)
    # NOTE(review): avatar paths are relative; confirm user.png / bot.png ship with the Space.
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False, avatar_images=("user.png", "bot.png"))
    chat_history = gr.State([])
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm that is red and itchy...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Every handler writes the visible chat AND the stored history, so the
    # Chatbot component actually shows the conversation.
    ui_outputs = [chatbot, chat_history, image_state, image_box, text_box]

    def clear_all():
        """Reset the visible chat, stored history, persisted image and inputs."""
        return [], [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(fn=clear_all, outputs=ui_outputs, queue=False)

    def on_submit(user_input, display_history, new_image, persisted_image):
        """Run one chat turn, ignoring submissions with neither text nor image."""
        if not (user_input or "").strip() and new_image is None:
            return display_history, display_history, persisted_image, None, ""
        history, image, cleared_image, cleared_text = symptom_checker_chat(
            user_input, display_history, new_image, persisted_image
        )
        return history, history, image, cleared_image, cleared_text

    # Clicking "Send" and pressing Enter in the textbox do the same thing.
    for trigger in (submit_btn.click, text_box.submit):
        trigger(
            fn=on_submit,
            inputs=[text_box, chat_history, image_box, image_state],
            outputs=ui_outputs,
        )

if __name__ == "__main__":
    demo.launch(debug=True)
# (A stale duplicate of the handler tail and Gradio UI was removed here; the app is defined once above.)
|