Update utils/xlnet_model.py
Files changed: utils/xlnet_model.py (+25 -4)
utils/xlnet_model.py
CHANGED
@@ -6,6 +6,8 @@ from transformers import XLNetModel, XLNetTokenizer
|
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
9 |
|
10 |
# Set Hugging Face cache directory
|
11 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
@@ -66,8 +68,27 @@ def fallback_similarity(t1, t2):
|
|
66 |
w1, w2 = set(t1.lower().split()), set(t2.lower().split())
|
67 |
return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0
|
68 |
|
69 |
-
|
|
|
|
|
70 |
try:
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
+
from torch.nn.functional import cosine_similarity
|
10 |
+
|
11 |
|
12 |
# Set Hugging Face cache directory
|
13 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
|
|
68 |
w1, w2 = set(t1.lower().split()), set(t2.lower().split())
|
69 |
return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0
|
70 |
|
71 |
+
from torch.nn.functional import cosine_similarity
|
72 |
+
|
73 |
+
def get_similarity_score(question, student, reference):
    """Score how closely a student's answer matches the reference answer.

    Both texts are embedded with the XLNet model (mean of the last hidden
    state over dim 1) and compared by cosine similarity. The similarity is
    then mapped linearly from [-1, 1] onto an integer in [0, 100].

    Args:
        question: The question being answered. Kept for interface
            compatibility; it is not used in the computation.
        student: The student's answer text.
        reference: The reference (expected) answer text.

    Returns:
        An integer score from 0 to 100. Returns 0 when either text is a
        blank string, when XLNet is not loaded, or when any error occurs
        while scoring (the error is printed, not raised).
    """
    # A blank answer has no content to compare, so score it 0 up front
    # instead of embedding it. This agrees with fallback_similarity, which
    # also returns 0 when either side is empty.
    if any(isinstance(text, str) and not text.strip() for text in (student, reference)):
        return 0

    try:
        if not xlnet_available:
            raise RuntimeError("XLNet not loaded")

        def encode(text):
            """Return the mean-pooled last hidden state for `text`."""
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            # Inference only: no gradients are needed.
            with torch.no_grad():
                # NOTE(review): assumes `model` exposes its backbone as
                # `.xlnet` -- confirm against where `model` is loaded.
                hidden = model.xlnet(**inputs).last_hidden_state
            # Presumably (batch, tokens, hidden); averaging dim 1 yields one
            # vector per input text -- TODO confirm.
            return hidden.mean(dim=1)

        student_embed = encode(student)
        reference_embed = encode(reference)

        # Call the torch implementation by its full path: the module imports
        # a `cosine_similarity` name from both sklearn and torch, and this
        # function needs the tensor version regardless of import order.
        sim = torch.nn.functional.cosine_similarity(
            student_embed, reference_embed, dim=1
        ).item()
        return round((sim + 1) / 2 * 100)  # Normalize [-1, 1] → [0, 100]

    except Exception as e:
        # Deliberate best-effort: a scoring failure yields 0 rather than
        # propagating to the caller.
        print("❌ Similarity error:", e)
        return 0
|
94 |
+
|