import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
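# `spaces` provides the @spaces.GPU decorator, which lets ZeroGPU Spaces
# attach a GPU to the decorated function on demand.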
# --- Configuration & Model Loading ---
# Use the image-text-to-text pipeline, which handles MedGemma's chat format
# (and is more robust, as seen in the working example).
print("Loading MedGemma model via pipeline...")

model_loaded = False
pipe = None
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),  # MedGemma is gated; the token's account must have accepted the license
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """
    Manages the conversation, building the chat-message format the pipeline expects.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    # Prefer a freshly uploaded image; otherwise fall back to the image
    # persisted from earlier turns.
    current_image = new_image_upload if new_image_upload is not None else image_state

    # Build the `messages` list in the chat format the pipeline expects:
    # each turn's `content` is a list of typed parts.
    messages = []
    # Replay the conversation history as alternating user/assistant turns.
    for user_msg, assistant_msg in history_for_display:
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    # Add the current user turn, attaching the image to this turn if one is available.
    user_content = [{"type": "text", "text": user_input}]
    if current_image is not None:
        user_content.insert(0, {"type": "image", "image": current_image})
    messages.append({"role": "user", "content": user_content})
    try:
        # With the image embedded in the message parts, a single call covers
        # both the text-only and the image-plus-text cases.
        output = pipe(
            text=messages,
            generate_kwargs={"max_new_tokens": 512, "do_sample": True, "temperature": 0.7},
        )
        # The pipeline returns the full conversation with the model's reply
        # appended as the last message; extract that message's text content.
        clean_response = output[0]["generated_text"][-1]["content"]
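        # Illustrative shape of `output` (reply text invented):
        # [{"input_text": [...the messages sent...],
        #   "generated_text": [...those same messages...,
        #       {"role": "assistant", "content": "How long have you had the rash?"}]}]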
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
        """
    )
    # Persist the uploaded image across turns; the Chatbot component itself
    # holds the (user, assistant) tuple history.
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False, avatar_images=("user.png", "bot.png"))  # avatar files are expected in the repo root
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm that is red and itchy...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)
    def clear_all():
        # Reset the conversation, the persisted image, the upload box, and the textbox.
        return [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(fn=clear_all, outputs=[chatbot, image_state, image_box, text_box], queue=False)
    def on_submit(user_input, display_history, new_image, persisted_image):
        # Ignore submissions that carry neither text nor a new image.
        if not user_input.strip() and not new_image:
            return display_history, persisted_image, None, ""
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
    # Event handlers for the submit button and the Enter key. The Chatbot
    # component doubles as the history input/output, so replies actually render.
    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )
if __name__ == "__main__":
    demo.launch(debug=True)
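# To run this outside Spaces (a sketch; assumes a CUDA GPU and that your
# Hugging Face account has accepted the MedGemma license):
#   pip install gradio spaces transformers accelerate torch pillow
#   HF_TOKEN=<your token> python app.py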