File size: 2,579 Bytes
3ec911c
c428c79
 
5888c6a
a47415e
 
c428c79
 
 
cd37260
a47415e
 
 
 
 
cd37260
5888c6a
 
 
 
cd37260
5888c6a
cd37260
5888c6a
 
cd37260
5888c6a
 
cd37260
5888c6a
cd37260
c428c79
5888c6a
 
 
 
 
 
 
 
cd37260
5888c6a
c428c79
cd37260
c428c79
5888c6a
 
cd37260
 
 
5888c6a
 
cd37260
5888c6a
cd37260
 
 
 
 
5888c6a
cd37260
5888c6a
cd37260
5888c6a
 
 
 
cd37260
 
5888c6a
cd37260
ff7d72e
cd37260
 
7963a78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import torch
from torch import nn
from transformers import XLNetModel, XLNetTokenizer
from huggingface_hub import hf_hub_download

# Set Hugging Face cache directory
os.environ["HF_HOME"] = "/tmp/huggingface"

# Download trained model weights from Hugging Face Hub
MODEL_PATH = hf_hub_download(
    repo_id="yeswanthvarma/xlnet-evaluator-model",
    filename="xlnet_answer_assessment_model.pt"
)

# Define your trained model architecture
class XLNetAnswerAssessmentModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.xlnet = XLNetModel.from_pretrained("xlnet-base-cased")
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, input_ids, attention_mask=None):
        pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(dim=1)
        x = torch.relu(self.fc1(pooled))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.output(x))  # Output: score in range [0, 1]

# Load tokenizer and model
xlnet_available = False
try:
    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
    model = XLNetAnswerAssessmentModel()
    model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    model.eval()
    xlnet_available = True
    print("✅ Custom XLNet model loaded.")
except Exception as e:
    print("⚠️  Could not load XLNet model → fallback will be used\n", e)

# -------------------------------
# Main prediction function
# -------------------------------
def get_model_prediction(q, s, r):
    if not xlnet_available:
        raise RuntimeError("XLNet model not available")

    # Combine input text as during training
    combined = f"{q} [SEP] {s} [SEP] {r}"
    inputs = tokenizer(combined, return_tensors="pt", truncation=True, max_length=512, padding=True)

    with torch.no_grad():
        output = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )
        score = output.squeeze().item() * 100  # Convert from [0,1] → [0,100]

    return round(score)

# Optional: Fallback similarity using word overlap
def fallback_similarity(t1, t2):
    w1, w2 = set(t1.lower().split()), set(t2.lower().split())
    return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0

# Final score API (use in app.py)
def get_similarity_score(q, s, r):
    try:
        return get_model_prediction(q, s, r)
    except Exception as e:
        print("❌ XLNet failed, using fallback:", e)
        return fallback_similarity(s, r)