import random
import re

import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
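# (per the repo name, a 0.5B Qwen2.5 checkpoint fine-tuned on MedMCQA)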
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
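
# Some causal-LM tokenizers ship without a pad token; the tokenizer call in
# predict() uses padding=True, so fall back to EOS (a common, safe default).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token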

# Load dataset
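# (fields used below: question, opa-opd = options A-D, cop = correct option index, exp = explanation)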
dataset = load_dataset("openlifescienceai/medmcqa")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option (0-3)
        question_data.get('exp', None)   # Explanation
    )

def extract_answer(prediction: str) -> tuple:
    """Extract answer and reasoning from model output"""
    # Try to find the answer part
    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
    answer = answer_match.group(1).upper() if answer_match else "Not found"
    
    # Try to find reasoning part
    reasoning = ""
    if "Reasoning:" in prediction:
        reasoning = prediction.split("Reasoning:")[-1].strip()
    elif "Explanation:" in prediction:
        reasoning = prediction.split("Explanation:")[-1].strip()
    
    return answer, reasoning

def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: float = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: float = 20):
    # Note: gr.Number and gr.Slider deliver floats, hence the float hints and int() casts below
    # Format the prompt
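    # (the trailing "Answer:" cue is what extract_answer later keys on)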
    prompt = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}\n\nAnswer:"
    
    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # gr.Slider passes a float
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,  # silences the missing-pad-token warning
        )
    
    # Get prediction
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_answer, model_reasoning = extract_answer(prediction)
    
    # Format output with evaluation if available
    output = prediction
    if correct_option is not None:
        # gr.Number delivers a float; chr() needs an int to convert 0-3 to A-D
        correct_letter = chr(65 + int(correct_option))
        is_correct = model_answer == correct_letter
        output += "\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"
    
    return output

# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")
    
    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)
            option_a = gr.Textbox(label="Option A", interactive=True)
            option_b = gr.Textbox(label="Option B", interactive=True)
            option_c = gr.Textbox(label="Option C", interactive=True)
            option_d = gr.Textbox(label="Option D", interactive=True)
            
            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused"
                )
                max_tokens = gr.Slider(
                    minimum=10,
                    maximum=512,
                    value=20,
                    step=10,
                    label="Max Tokens",
                    info="Maximum number of new tokens to generate"
                )
            
            # Hidden fields for correct answer and explanation
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)
            
            # Buttons
            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")
            
            # Output
            output = gr.Textbox(label="Model's Answer", lines=10)
    
    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )

# Launch the app
if __name__ == "__main__":
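    # share=True here would also serve a temporary public URL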
    demo.launch()