File size: 1,446 Bytes
76c0b3d
 
 
 
52f712f
76c0b3d
 
 
 
 
52f712f
76c0b3d
 
 
 
 
 
52f712f
 
d86ca36
52f712f
76c0b3d
52f712f
 
 
76c0b3d
52f712f
76c0b3d
52f712f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76c0b3d
 
52f712f
 
76c0b3d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import streamlit as st
import onnxruntime as ort
from transformers import AutoTokenizer

# Cache the tokenizer and ONNX session: Streamlit re-runs this script on
# every widget interaction, and without caching both would be reloaded
# from disk on each button click.
@st.cache_resource
def _load_tokenizer():
    """Load the DeBERTa-v3 MNLI tokenizer once per server process."""
    return AutoTokenizer.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")


@st.cache_resource
def _load_session():
    """Create the ONNX Runtime inference session once per server process."""
    return ort.InferenceSession("DeBERTaNLI.onnx")


tokenizer = _load_tokenizer()
session = _load_session()

# Map the model's output class indices (0, 1, 2) to human-readable NLI labels.
label_map = dict(enumerate(["Entailment", "Neutral", "Contradiction"]))

# App UI
st.set_page_config(page_title="DeBERTa NLI Inference", page_icon="")
st.title(" DeBERTa-v3 NLI Inference ")
st.write("Predict the relationship between a **Premise** and a **Hypothesis** using a DeBERTa-v3 ONNX model.")

# Input fields
premise = st.text_area("Premise", "A man is playing guitar on stage.")
hypothesis = st.text_area("Hypothesis", "A musician performs for a crowd.")

# Predict button
if st.button("Predict"):
    # Guard: running the model on blank text yields a meaningless prediction.
    if not premise.strip() or not hypothesis.strip():
        st.warning("Please enter both a premise and a hypothesis.")
    else:
        # Tokenize the sentence pair; NLI models consume (premise, hypothesis)
        # as a single joint input sequence.
        inputs = tokenizer(
            premise,
            hypothesis,
            return_tensors="np",
            padding="max_length",
            max_length=128,
            truncation=True,
        )

        # Keep only the inputs the exported ONNX graph actually declares
        # (some exports drop token_type_ids).
        valid_input_names = {inp.name for inp in session.get_inputs()}
        ort_inputs = {k: v for k, v in inputs.items() if k in valid_input_names}

        # Run inference; outputs[0] holds the logits with shape (1, num_labels),
        # so argmax over axis 1 picks the predicted class for the single example.
        outputs = session.run(None, ort_inputs)
        prediction = int(outputs[0].argmax(axis=1)[0])

        # Display result
        st.success(f"**Prediction:** {label_map[prediction]}")