# app.py -- commit 3baf600 (verified), "Update app.py", uploaded by kwanpon.
import torch
import gradio as gr
import pandas as pd
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer
)
# --- Data ---------------------------------------------------------------
# Read the training examples from the CSV next to the app and wrap the
# resulting DataFrame in a Hugging Face ``Dataset``.
df = pd.read_csv("dataset.csv")
dataset = Dataset.from_pandas(df)

# --- Tokenizer & model ----------------------------------------------------
# Both are loaded from the same pretrained checkpoint.
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model is requested with a 2-label classification head;
# ``ignore_mismatched_sizes=True`` is passed presumably because the
# checkpoint's own head has a different size -- confirm against the model card.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
# Tokenization step applied to the dataset via ``Dataset.map``.
def preprocess(examples):
    """Tokenize the ``"text"`` field of ``examples``.

    The texts are truncated and padded by the module-level ``tokenizer``;
    the tokenizer's output is returned unchanged.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding=True)
    return encoded
# Run ``preprocess`` over the dataset in batches.
tokenized_dataset = dataset.map(preprocess, batched=True)

# Fine-tuning hyper-parameters. ``save_strategy="no"`` is set, so the run
# is not asked to save checkpoints under ``output_dir``.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    logging_steps=10,
    save_strategy="no",
    learning_rate=2e-5,
)

# Fine-tune ``model`` on the tokenized dataset. These are module-level
# statements, so training runs every time this file is executed or imported.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)
trainer.train()
# Inference function used by the Gradio interface.
def classify(text):
    """Classify ``text`` with the fine-tuned model.

    Args:
        text: The input string to classify.

    Returns:
        A dict mapping the two Thai class labels to their softmax
        probabilities (Python floats): index 0 -> "ไม่เกี่ยวข้อง"
        ("not relevant"), index 1 -> "จ้างงานรถขนส่ง" ("transport-vehicle
        hiring job").
    """
    # ``trainer.train()`` leaves the model in training mode; switch to eval
    # mode so dropout is disabled and predictions are deterministic.
    model.eval()

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # The Trainer may have moved the model to an accelerator; the tokenizer
    # always produces CPU tensors, so move them to wherever the model lives.
    target_device = model.device
    inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring the result back to the CPU before converting to Python floats;
    # converting a CUDA tensor directly to NumPy would raise.
    probs = torch.softmax(outputs.logits, dim=1)[0].cpu().tolist()
    return {
        "ไม่เกี่ยวข้อง": float(probs[0]),
        "จ้างงานรถขนส่ง": float(probs[1]),
    }
# Gradio UI: one multi-line text box in, one label widget out, wired to
# ``classify``. All user-facing strings are in Thai.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=3, label="ข้อความ"),
    outputs=gr.Label(label="ผลการจำแนก"),
    title="Text Classifier: Zero-Shot NLI",
    description="กรุณาพิมพ์ข้อความเพื่อตรวจสอบว่าเป็นการว่าจ้างงานรถขนส่งหรือไม่",
)

# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()