yeswanthvarma committed on
Commit
ff7d72e
·
verified ·
1 Parent(s): 5e571b5

Update utils/xlnet_model.py

Browse files
Files changed (1) hide show
  1. utils/xlnet_model.py +25 -4
utils/xlnet_model.py CHANGED
@@ -6,6 +6,8 @@ from transformers import XLNetModel, XLNetTokenizer
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from huggingface_hub import hf_hub_download
 
 
9
 
10
  # Set Hugging Face cache directory
11
  os.environ["HF_HOME"] = "/tmp/huggingface"
@@ -66,8 +68,27 @@ def fallback_similarity(t1, t2):
66
  w1, w2 = set(t1.lower().split()), set(t2.lower().split())
67
  return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0
68
 
69
- def get_similarity_score(q, s, r):
 
 
70
  try:
71
- return get_model_prediction(q, s, r) if xlnet_available else tfidf_similarity(s, r)
72
- except Exception:
73
- return fallback_similarity(s, r)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
  from huggingface_hub import hf_hub_download
9
+ from torch.nn.functional import cosine_similarity
10
+
11
 
12
  # Set Hugging Face cache directory
13
  os.environ["HF_HOME"] = "/tmp/huggingface"
 
68
  w1, w2 = set(t1.lower().split()), set(t2.lower().split())
69
  return round(len(w1 & w2) / len(w1 | w2) * 100) if w1 and w2 else 0
70
 
71
+ from torch.nn.functional import cosine_similarity
72
+
73
def get_similarity_score(question, student, reference):
    """Score how similar a student answer is to a reference answer (0-100).

    The score is the cosine similarity between mean-pooled XLNet hidden
    states of the two texts, rescaled from [-1, 1] to [0, 100]. If XLNet is
    not loaded, or anything goes wrong while computing the embeddings, the
    function degrades to the word-overlap score from ``fallback_similarity``
    instead of reporting 0 for every answer.

    Args:
        question: The question text. Not used by the embedding comparison;
            kept so existing callers passing three arguments keep working.
        student: The student's answer text.
        reference: The reference (model) answer text.

    Returns:
        An integer in the range 0..100. Empty or whitespace-only answers
        always score 0.
    """
    # An empty answer would still be tokenized into special tokens and could
    # receive a misleadingly high embedding similarity, so short-circuit.
    if not (student and student.strip()) or not (reference and reference.strip()):
        return 0

    def encode(text):
        # Mean-pool the last hidden state over the sequence dimension.
        # NOTE(review): assumes the module-level `model` exposes its XLNet
        # backbone as `.xlnet` -- confirm against the model class.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            return model.xlnet(**inputs).last_hidden_state.mean(dim=1)

    try:
        if not xlnet_available:
            raise RuntimeError("XLNet not loaded")

        # Use an explicit alias: the module imports a `cosine_similarity`
        # from both sklearn and torch, so the bare name is ambiguous.
        import torch.nn.functional as torch_functional

        sim = torch_functional.cosine_similarity(
            encode(student), encode(reference)
        ).item()
        score = round((sim + 1) / 2 * 100)  # Normalize [-1, 1] -> [0, 100]
        # Guard against floating-point drift just outside the valid range.
        return max(0, min(100, score))
    except Exception as e:
        print("❌ Similarity error:", e)
        # Best-effort lexical score rather than a hard 0.
        return fallback_similarity(student, reference)
94
+