|
import gradio as gr |
|
from transformers import pipeline |
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Maps a model name to its loaded fill-mask pipeline. fill_masked() fills it
# lazily so each model is instantiated at most once per process.
pipeline_cache = {}
|
|
|
|
|
def get_model_choices():
    """Return the model identifiers offered in the model dropdown.

    Returns:
        A new list of model-name strings, in display order.
    """
    return [
        "UMCU/CardioMedRoBERTa.nl",
        "UMCU/CardioBERTa_base.nl",
        "UMCU/CardioBERTa.nl_clinical",
        "UMCU/CardioDeBERTa.nl",
        "UMCU/CardioDeBERTa.nl_clinical",
        "CLTL/MedRoBERTa.nl",
        "DTAI-KULeuven/robbert-2023-dutch-base",
        "DTAI-KULeuven/robbert-2023-dutch-large",
        "joeranbosma/dragon-bert-base-mixed-domain",
        "joeranbosma/dragon-bert-base-domain-specific",
        "joeranbosma/dragon-roberta-base-mixed-domain",
        "joeranbosma/dragon-roberta-large-mixed-domain",
        "joeranbosma/dragon-roberta-base-domain-specific",
        "joeranbosma/dragon-roberta-large-domain-specific",
        "joeranbosma/dragon-longformer-base-mixed-domain",
        "joeranbosma/dragon-longformer-large-mixed-domain",
        "joeranbosma/dragon-longformer-base-domain-specific",
        "joeranbosma/dragon-longformer-large-domain-specific",
    ]
|
|
|
|
|
def _format_predictions(predictions):
    """Keep only the fields of raw fill-mask predictions that the UI shows."""
    return [
        {
            "sequence": prediction["sequence"],
            "score": round(prediction["score"], 4),
            "token": prediction["token_str"],
        }
        for prediction in predictions
    ]


def fill_masked(text: str, model_name: str, top_k: int):
    """Return the top predictions for the masked position(s) in ``text``.

    Every literal ``[MASK]`` in ``text`` is replaced by the selected model's
    own mask token before the text is handed to the fill-mask pipeline.

    Args:
        text: Sentence containing one or more ``[MASK]`` placeholders.
        model_name: Model identifier passed to ``transformers.pipeline``. The
            pipeline is created on first use and kept in ``pipeline_cache``.
        top_k: Number of predictions to return per masked position.

    Returns:
        For a single mask, a list of dicts with the keys ``"sequence"``,
        ``"score"`` (rounded to 4 decimals) and ``"token"``. For several
        masks, a list holding one such list per masked position.

    Raises:
        gr.Error: If ``text`` contains no mask token after substitution.
    """
    if model_name not in pipeline_cache:
        pipeline_cache[model_name] = pipeline("fill-mask", model=model_name)
    fill_mask = pipeline_cache[model_name]

    # Reuse the tokenizer the cached pipeline already holds instead of
    # loading a second copy from disk/hub on every request.
    mask_token = fill_mask.tokenizer.mask_token
    text = (text or "").replace("[MASK]", mask_token)

    # Fail early with a message the UI can display rather than letting the
    # pipeline raise its own exception for mask-less input.
    if mask_token not in text:
        raise gr.Error("The input text must contain at least one [MASK] token.")

    # int() guards against the slider delivering a float such as 5.0.
    results = fill_mask(text, top_k=int(top_k))

    # With more than one mask the pipeline returns one prediction list per
    # masked position instead of a flat list of dicts.
    if results and isinstance(results[0], list):
        return [_format_predictions(per_mask) for per_mask in results]
    return _format_predictions(results)
|
|
|
|
|
# Build the choice list once so the dropdown's options and its default value
# are guaranteed to come from the same list.
_model_choices = get_model_choices()

# Gradio UI: masked text + model + top-k in, JSON predictions out.
iface = gr.Interface(
    fn=fill_masked,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Type text with [MASK] tokens here...",
            label="Masked Text",
        ),
        gr.Dropdown(
            choices=_model_choices,
            # The default must be one of the offered choices; a value outside
            # the list is not a valid selection for the dropdown.
            value=_model_choices[0],
            label="Model",
        ),
        gr.Slider(
            minimum=1,
            maximum=20,
            step=1,
            value=5,
            label="Top K Predictions",
        ),
    ],
    outputs=gr.JSON(label="Predictions"),
    title="Masked Language Model tester",
    description="Enter a sentence with [MASK] tokens, select a model, and choose how many top predictions to return.",
)


# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|
|