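"""Gradio demo for BioXP-0.5B-MedMCQA: multiple-choice medical question answering.

Loads the fine-tuned model, samples questions from the MedMCQA dataset, and
grades the model's answer against the dataset's answer key when available.
"""
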
import random
import re

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")
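# Each MedMCQA example provides the fields used below: 'question', the four
# options 'opa'-'opd', 'cop' (index of the correct option, 0-3), and an
# optional free-text explanation 'exp'.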
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
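# eval() switches off training-time behavior such as dropout; together with
# torch.no_grad() in predict(), this keeps generation inference-only.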


def get_random_question():
    """Return a random MedMCQA question, its options, answer key, and explanation."""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option index (0-3)
        question_data.get('exp', None),  # Expert explanation, if any
    )


def extract_answer(prediction: str) -> tuple:
    """Extract the answer letter and any reasoning text from the model output."""
    # Look for a pattern like "Answer: B" (case-insensitive)
    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
    answer = answer_match.group(1).upper() if answer_match else "Not found"
    # Capture whatever follows a "Reasoning:" or "Explanation:" marker, if any
    reasoning = ""
    if "Reasoning:" in prediction:
        reasoning = prediction.split("Reasoning:")[-1].strip()
    elif "Explanation:" in prediction:
        reasoning = prediction.split("Explanation:")[-1].strip()
    return answer, reasoning
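
# Illustrative example (hypothetical output text, not from the repo):
#   extract_answer("Answer: B\nReasoning: Beta blockers reduce heart rate.")
#   -> ("B", "Beta blockers reduce heart rate.")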


def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 20):
    """Generate the model's answer and, when the answer key is known, grade it."""
    # Format the prompt
    prompt = (
        f"Question: {question}\n\nOptions:\n"
        f"A. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
    )
    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
        )
    # Decode the full sequence (prompt plus completion)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_answer, model_reasoning = extract_answer(prediction)  # reasoning is currently unused
    # Append an evaluation block when the answer key is available
    output = prediction
    if correct_option is not None:
        # Gradio's Number component may deliver a float, so cast before chr()
        correct_letter = chr(65 + int(correct_option))  # Convert 0-3 to A-D
        is_correct = model_answer == correct_letter
        output += "\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"
    return output
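
# Illustrative call (the question and options are made up for this sketch):
#   predict("Which vitamin deficiency causes scurvy?",
#           "Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D",
#           correct_option=2)
# returns the decoded generation followed by an Evaluation block that
# compares the extracted letter against "C".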


# Create Gradio interface with Blocks for more control over the layout
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)

            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused",
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=512,
                    value=20,
                    step=1,  # a step of 32 would make the default of 20 unreachable
                    label="Max Tokens",
                    info="Maximum number of new tokens to generate",
                )
    # Hidden fields that carry the answer key and explanation for a random question
    correct_option = gr.Number(visible=False)
    expert_explanation = gr.Textbox(visible=False)

    # Buttons
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=10)

    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens,
        ],
        outputs=output,
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation],
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
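
# Note: demo.launch(share=True) would additionally expose a temporary public
# Gradio link, which can be handy when testing outside the Space.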