# Source: Hugging Face file "app.py" uploaded by cperiya (commit d86ca36, verified).
import streamlit as st
import onnxruntime as ort
from transformers import AutoTokenizer
# --- Model setup ---
# Streamlit re-executes this script on every widget interaction, so the
# tokenizer and the ONNX session must be cached or they would be reloaded
# (and the tokenizer re-downloaded) on every click.
@st.cache_resource
def _load_tokenizer():
    """Return the DeBERTa-v3 MNLI tokenizer, cached for the app's lifetime."""
    return AutoTokenizer.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")


@st.cache_resource
def _load_session():
    """Return the ONNX Runtime inference session, cached for the app's lifetime."""
    # NOTE(review): assumes "DeBERTaNLI.onnx" sits next to app.py in the
    # working directory — confirm against the deployment layout.
    return ort.InferenceSession("DeBERTaNLI.onnx")


tokenizer = _load_tokenizer()
session = _load_session()

# Class index -> human-readable NLI label (standard MNLI head ordering).
label_map = {
    0: "Entailment",
    1: "Neutral",
    2: "Contradiction",
}
# --- Page chrome and input widgets ---
# set_page_config must be the first Streamlit call rendered on the page.
st.set_page_config(page_title="DeBERTa NLI Inference", page_icon="")
st.title(" DeBERTa-v3 NLI Inference ")
st.write("Predict the relationship between a **Premise** and a **Hypothesis** using a DeBERTa-v3 ONNX model.")

# Two free-text inputs forming the NLI sentence pair, pre-filled with a
# worked example so the demo runs out of the box.
_example_premise = "A man is playing guitar on stage."
_example_hypothesis = "A musician performs for a crowd."
premise = st.text_area("Premise", _example_premise)
hypothesis = st.text_area("Hypothesis", _example_hypothesis)
# --- Prediction ---
# The pasted source had lost the indentation of this `if` body, which is a
# syntax error; restored here. Runs a single forward pass when clicked.
if st.button("Predict"):
    # Tokenize the (premise, hypothesis) pair into fixed-length NumPy arrays,
    # padded/truncated to 128 tokens as required by the exported graph.
    inputs = tokenizer(
        premise,
        hypothesis,
        return_tensors="np",
        padding="max_length",
        max_length=128,
        truncation=True,
    )
    # Some ONNX exports drop token_type_ids; feed only the inputs the
    # session's graph actually declares to avoid an InvalidArgument error.
    valid_input_names = {inp.name for inp in session.get_inputs()}
    ort_inputs = {k: v for k, v in inputs.items() if k in valid_input_names}
    # Run inference; outputs[0] is presumably the (1, 3) logits tensor —
    # argmax over the class axis picks the predicted relation.
    outputs = session.run(None, ort_inputs)
    prediction = int(outputs[0].argmax(axis=1)[0])
    # Map the class index to its label and show it.
    st.success(f"**Prediction:** {label_map[prediction]}")