Spaces:
Sleeping
Sleeping
Update utils/xlnet_model.py
Browse files- utils/xlnet_model.py +25 -45
utils/xlnet_model.py
CHANGED
@@ -1,41 +1,34 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
-
import numpy as np
|
4 |
from torch import nn
|
5 |
from transformers import XLNetModel, XLNetTokenizer
|
6 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
-
from torch.nn.functional import cosine_similarity
|
10 |
-
|
11 |
|
12 |
# Set Hugging Face cache directory
|
13 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
14 |
|
15 |
-
# Download model weights from Hugging Face Hub
|
16 |
MODEL_PATH = hf_hub_download(
|
17 |
repo_id="yeswanthvarma/xlnet-evaluator-model",
|
18 |
filename="xlnet_answer_assessment_model.pt"
|
19 |
)
|
20 |
|
21 |
-
# Define your
|
22 |
class XLNetAnswerAssessmentModel(nn.Module):
|
23 |
def __init__(self):
|
24 |
super().__init__()
|
25 |
self.xlnet = XLNetModel.from_pretrained("xlnet-base-cased")
|
26 |
-
|
27 |
-
self.fc1 = nn.Linear(hidden, 256)
|
28 |
self.fc2 = nn.Linear(256, 64)
|
29 |
-
self.output = nn.Linear(64, 1)
|
30 |
|
31 |
def forward(self, input_ids, attention_mask=None):
|
32 |
-
pooled = self.xlnet(input_ids, attention_mask).last_hidden_state.mean(1)
|
33 |
x = torch.relu(self.fc1(pooled))
|
34 |
x = torch.relu(self.fc2(x))
|
35 |
-
return torch.sigmoid(self.output(x)) #
|
36 |
-
|
37 |
|
38 |
-
#
|
39 |
xlnet_available = False
|
40 |
try:
|
41 |
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
|
@@ -45,50 +38,37 @@ try:
|
|
45 |
xlnet_available = True
|
46 |
print("✅ Custom XLNet model loaded.")
|
47 |
except Exception as e:
|
48 |
-
print("⚠️ Could not load XLNet model → fallback
|
49 |
|
50 |
# -------------------------------
|
51 |
-
#
|
52 |
# -------------------------------
|
53 |
def get_model_prediction(q, s, r):
|
54 |
if not xlnet_available:
|
55 |
-
raise
|
|
|
|
|
56 |
combined = f"{q} [SEP] {s} [SEP] {r}"
|
57 |
inputs = tokenizer(combined, return_tensors="pt", truncation=True, max_length=512, padding=True)
|
|
|
58 |
with torch.no_grad():
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
61 |
|
62 |
-
|
63 |
-
vec = TfidfVectorizer()
|
64 |
-
mat = vec.fit_transform([t1, t2])
|
65 |
-
return round(cosine_similarity(mat[0], mat[1])[0][0] * 100)
|
66 |
|
|
|
67 |
def fallback_similarity(t1, t2):
|
68 |
w1, w2 = set(t1.lower().split()), set(t2.lower().split())
|
69 |
return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
def get_similarity_score(question, student, reference):
|
74 |
try:
|
75 |
-
|
76 |
-
raise RuntimeError("XLNet not loaded")
|
77 |
-
|
78 |
-
def encode(text):
|
79 |
-
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
|
80 |
-
with torch.no_grad():
|
81 |
-
output = model.xlnet(**inputs).last_hidden_state.mean(dim=1)
|
82 |
-
return output
|
83 |
-
|
84 |
-
student_embed = encode(student)
|
85 |
-
reference_embed = encode(reference)
|
86 |
-
|
87 |
-
sim = cosine_similarity(student_embed, reference_embed).item()
|
88 |
-
score = round((sim + 1) / 2 * 100) # Normalize [-1, 1] → [0, 100]
|
89 |
-
return score
|
90 |
-
|
91 |
except Exception as e:
|
92 |
-
print("❌
|
93 |
-
return
|
94 |
-
|
|
|
1 |
import os
|
2 |
import torch
|
|
|
3 |
from torch import nn
|
4 |
from transformers import XLNetModel, XLNetTokenizer
|
|
|
|
|
5 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
6 |
|
7 |
# Set Hugging Face cache directory.
# NOTE: huggingface_hub and transformers are already imported at this point,
# and they resolve their cache locations from the environment when first
# imported, so this assignment on its own may not redirect the cache. It is
# kept for any code that reads the variable later; the download below is given
# the directory explicitly so it does not depend on import order.
os.environ["HF_HOME"] = "/tmp/huggingface"

# Download trained model weights from Hugging Face Hub.
# "<HF_HOME>/hub" is the location the hub cache uses by default when HF_HOME
# is honored, so the explicit cache_dir matches the intended layout.
MODEL_PATH = hf_hub_download(
    repo_id="yeswanthvarma/xlnet-evaluator-model",
    filename="xlnet_answer_assessment_model.pt",
    cache_dir=os.path.join(os.environ["HF_HOME"], "hub"),
)
|
15 |
|
16 |
+
# Define your trained model architecture
|
17 |
class XLNetAnswerAssessmentModel(nn.Module):
    """Pretrained XLNet encoder with a small feed-forward scoring head.

    The encoder's ``last_hidden_state`` is averaged over dimension 1 and fed
    through two ReLU-activated linear layers (768 -> 256 -> 64) followed by a
    single sigmoid unit.

    NOTE(review): the attribute names (``xlnet``, ``fc1``, ``fc2``,
    ``output``) and layer sizes presumably have to match the saved checkpoint;
    the loading code is not visible here, so confirm before renaming.
    """

    def __init__(self):
        super().__init__()
        # Backbone encoder, loaded from the public pretrained checkpoint.
        self.xlnet = XLNetModel.from_pretrained("xlnet-base-cased")
        # Scoring head: 768 -> 256 -> 64 -> 1.
        self.fc1 = nn.Linear(768, 256)
        self.fc2 = nn.Linear(256, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return the sigmoid-activated head output for the given token ids.

        Args:
            input_ids: Token id tensor passed straight to the XLNet encoder.
            attention_mask: Optional mask passed to the encoder as its second
                positional argument.

        Returns:
            Tensor of values in [0, 1] produced by the final sigmoid.
        """
        encoder_out = self.xlnet(input_ids, attention_mask)
        # Plain (unmasked) mean over dim 1 of the encoder output.
        sequence_mean = encoder_out.last_hidden_state.mean(dim=1)
        hidden = torch.relu(self.fc1(sequence_mean))
        hidden = torch.relu(self.fc2(hidden))
        logit = self.output(hidden)
        return torch.sigmoid(logit)  # Output: score in range [0, 1]
|
|
|
30 |
|
31 |
+
# Load tokenizer and model
|
32 |
xlnet_available = False
|
33 |
try:
|
34 |
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
|
|
|
38 |
xlnet_available = True
|
39 |
print("✅ Custom XLNet model loaded.")
|
40 |
except Exception as e:
|
41 |
+
print("⚠️ Could not load XLNet model → fallback will be used\n", e)
|
42 |
|
43 |
# -------------------------------
|
44 |
+
# Main prediction function
|
45 |
# -------------------------------
|
46 |
def get_model_prediction(q, s, r):
    """Score an answer with the loaded XLNet model on a 0-100 scale.

    Args:
        q: Question text.
        s: Student answer text.
        r: Reference answer text.

    Returns:
        int: The model output (in [0, 1]) multiplied by 100 and rounded.

    Raises:
        RuntimeError: If ``xlnet_available`` is false, i.e. the model was not
            loaded at import time.
    """
    if not xlnet_available:
        raise RuntimeError("XLNet model not available")

    # Combine input text as during training.
    combined = f"{q} [SEP] {s} [SEP] {r}"
    encoded = tokenizer(
        combined,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        prediction = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # Convert from [0, 1] to [0, 100].
    probability = prediction.squeeze().item()
    return round(probability * 100)
|
|
|
|
|
|
|
62 |
|
63 |
+
# Optional: Fallback similarity using word overlap
|
64 |
def fallback_similarity(t1, t2):
    """Return the word-overlap (Jaccard) similarity of two texts as 0-100.

    Both texts are lower-cased and split on whitespace; punctuation attached
    to the start or end of a word is stripped so that ``"capital."`` and
    ``"capital"`` count as the same word. The score is the size of the shared
    vocabulary divided by the size of the combined vocabulary, times 100,
    rounded to an int.

    This is the last-resort scorer, so it must not raise on odd input:
    ``None`` (or any other falsy value) is treated as empty text.

    Args:
        t1: First text (e.g. the student answer).
        t2: Second text (e.g. the reference answer).

    Returns:
        int: Similarity in [0, 100]; 0 if either text has no words.
    """
    import string  # local import keeps this block self-contained

    def _words(text):
        # Coerce None/non-str so the fallback itself cannot crash.
        raw = str(text or "").lower().split()
        stripped = (token.strip(string.punctuation) for token in raw)
        return {token for token in stripped if token}

    w1, w2 = _words(t1), _words(t2)
    if not w1 or not w2:
        return 0
    return round(len(w1 & w2) / len(w1 | w2) * 100)
|
67 |
|
68 |
+
# Final score API (use in app.py)
|
69 |
+
def get_similarity_score(q, s, r):
    """Final score API (use in app.py).

    Tries the XLNet model first; if that raises for any reason, the error is
    printed and the word-overlap fallback is used on the student and
    reference answers instead.

    Args:
        q: Question text.
        s: Student answer text.
        r: Reference answer text.

    Returns:
        The result of ``get_model_prediction(q, s, r)``, or of
        ``fallback_similarity(s, r)`` when the model path fails.
    """
    try:
        score = get_model_prediction(q, s, r)
    except Exception as err:
        print("❌ XLNet failed, using fallback:", err)
        return fallback_similarity(s, r)
    return score
|
|