# Page-scrape residue captured with the source (not part of the script): "Spaces: Sleeping"
import onnxruntime as ort
import streamlit as st
from transformers import AutoTokenizer
# Streamlit re-executes this script from the top on every widget interaction,
# so the tokenizer and the ONNX session are built inside cached loaders and
# reused across reruns instead of being reloaded each time.
# show_spinner=False is meant to keep the loaders from rendering a spinner
# element before st.set_page_config() is called further down.
@st.cache_resource(show_spinner=False)
def _load_tokenizer():
    """Return the tokenizer that matches the NLI checkpoint."""
    return AutoTokenizer.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")


@st.cache_resource(show_spinner=False)
def _load_session():
    """Return an ONNX Runtime inference session for the local model file."""
    return ort.InferenceSession("DeBERTaNLI.onnx")


# Module-level names used by the rest of the script.
tokenizer = _load_tokenizer()
session = _load_session()

# Label mapping for output: class index (argmax of the model output) -> display name.
# NOTE(review): the index order must match the exported model's label order — confirm.
label_map = {
    0: "Entailment",
    1: "Neutral",
    2: "Contradiction",
}
# Page chrome: browser-tab metadata, heading, and a one-line description.
st.set_page_config(
    page_icon="",
    page_title="DeBERTa NLI Inference",
)
st.title(" DeBERTa-v3 NLI Inference ")
st.write(
    "Predict the relationship between a **Premise** and a **Hypothesis** "
    "using a DeBERTa-v3 ONNX model."
)

# Sentence pair supplied by the user; both boxes start pre-filled with an example.
premise = st.text_area(
    label="Premise",
    value="A man is playing guitar on stage.",
)
hypothesis = st.text_area(
    label="Hypothesis",
    value="A musician performs for a crowd.",
)
# Predict button: the body runs only on the rerun in which the button was clicked.
if st.button("Predict"):
    if not premise.strip() or not hypothesis.strip():
        # A blank side would still be tokenized and scored, yielding a label
        # for text the user never entered, so ask for input instead.
        st.warning("Please enter both a premise and a hypothesis.")
    else:
        # Encode the sentence pair as NumPy arrays, padded/truncated to 128 tokens.
        inputs = tokenizer(
            premise,
            hypothesis,
            return_tensors="np",
            padding="max_length",
            max_length=128,
            truncation=True,
        )

        # Feed only the tensors the ONNX graph declares as inputs; the tokenizer
        # may return extras (e.g. token_type_ids) that the graph does not expect.
        valid_input_names = {inp.name for inp in session.get_inputs()}
        ort_inputs = {k: v for k, v in inputs.items() if k in valid_input_names}

        # Run inference and take the highest-scoring class index of the first
        # (only) row of the first output.
        # NOTE(review): assumes outputs[0] holds per-class scores of shape
        # (batch, num_labels) — confirm against the exported model.
        outputs = session.run(None, ort_inputs)
        prediction = int(outputs[0].argmax(axis=1)[0])

        # Display result
        st.success(f"**Prediction:** {label_map[prediction]}")