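"""Gradio demo (Hugging Face Space) for a medical MCQ model.

Loads the rgb2gbr/BioXP-0.5B-MedMCQA causal LM, samples questions from the
openlifescienceai/medmcqa dataset, and lets the user compare the model's
predicted option (A-D) against the dataset's reference answer.
"""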
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re
# Load model and tokenizer
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
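# Optional (not used here): on a GPU Space the model could be loaded in half
# precision to reduce memory, e.g.
#   model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)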
def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option (0-3)
        question_data.get('exp', None)   # Explanation
    )
def extract_answer(prediction: str) -> tuple:
    """Extract answer and reasoning from model output"""
    # Try to find the answer part
    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
    answer = answer_match.group(1).upper() if answer_match else "Not found"

    # Try to find reasoning part
    reasoning = ""
    if "Reasoning:" in prediction:
        reasoning = prediction.split("Reasoning:")[-1].strip()
    elif "Explanation:" in prediction:
        reasoning = prediction.split("Explanation:")[-1].strip()

    return answer, reasoning
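# Illustrative example (hypothetical model output, not executed):
#   extract_answer("... Answer: B\nReasoning: option B blocks the receptor")
#   -> ("B", "option B blocks the receptor")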
def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 20):
    # Format the prompt
    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # Gradio sliders pass floats
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # pad_token_id=tokenizer.eos_token_id
        )

    # Get prediction
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_answer, model_reasoning = extract_answer(prediction)

    # Format output with evaluation if available
    output = prediction
    if correct_option is not None:
        correct_letter = chr(65 + int(correct_option))  # Convert 0-3 to A-D (gr.Number passes a float)
        is_correct = model_answer == correct_letter
        output += f"\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"

    return output
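# Illustrative direct call (hypothetical question, bypassing the UI):
#   print(predict("Which vitamin deficiency causes scurvy?",
#                 "Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"))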
# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)

            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused"
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=512,
                    value=20,
                    step=32,
                    label="Max Tokens",
                    info="Maximum length of the generated response"
                )

            # Hidden fields for correct answer and explanation
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)

    # Buttons
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=10)
    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()
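    # Optional: demo.launch(share=True) would also create a temporary public link
    # when running locally; on Hugging Face Spaces the app is served automatically.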