Update utils/xlnet_model.py
Browse files - utils/xlnet_model.py +14 -40
utils/xlnet_model.py
CHANGED
@@ -1,48 +1,22 @@
|
|
1 |
-
import requests, tqdm, torch, numpy as np
|
2 |
-
from torch import nn
|
3 |
import os
|
4 |
-
|
|
|
|
|
5 |
from transformers import XLNetModel, XLNetTokenizer
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
8 |
-
|
9 |
-
# ------------------------------------------------------------------
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
-
import torch
|
12 |
|
13 |
-
#
|
|
|
|
|
|
|
14 |
MODEL_PATH = hf_hub_download(
|
15 |
repo_id="yeswanthvarma/xlnet-evaluator-model",
|
16 |
filename="xlnet_answer_assessment_model.pt"
|
17 |
)
|
18 |
|
19 |
-
#
|
20 |
-
# Load tokenizer and model architecture
|
21 |
-
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
|
22 |
-
model = XLNetModel.from_pretrained("xlnet-base-cased")
|
23 |
-
|
24 |
-
# Then load your custom weights
|
25 |
-
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
|
26 |
-
model.eval()
|
27 |
-
|
28 |
-
def download_model_if_needed():
|
29 |
-
if os.path.exists(MODEL_PATH):
|
30 |
-
return
|
31 |
-
print("▶️ Downloading XLNet weights from Hugging Face …")
|
32 |
-
with requests.get(HF_URL, stream=True) as r:
|
33 |
-
r.raise_for_status()
|
34 |
-
total = int(r.headers.get("content-length", 0))
|
35 |
-
with open(MODEL_PATH, "wb") as f, tqdm.tqdm(total=total, unit="B", unit_scale=True) as bar:
|
36 |
-
for chunk in r.iter_content(chunk_size=8192):
|
37 |
-
f.write(chunk)
|
38 |
-
bar.update(len(chunk))
|
39 |
-
print("✅ Download complete.")
|
40 |
-
|
41 |
-
download_model_if_needed()
|
42 |
-
# ------------------------------------------------------------------
|
43 |
-
|
44 |
-
xlnet_available = False # will flip to True if load succeeds
|
45 |
-
|
46 |
class XLNetAnswerAssessmentModel(nn.Module):
|
47 |
def __init__(self):
|
48 |
super().__init__()
|
@@ -58,6 +32,8 @@ class XLNetAnswerAssessmentModel(nn.Module):
|
|
58 |
x = torch.relu(self.fc2(x))
|
59 |
return torch.sigmoid(self.out(x))
|
60 |
|
|
|
|
|
61 |
try:
|
62 |
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
|
63 |
model = XLNetAnswerAssessmentModel()
|
@@ -68,11 +44,9 @@ try:
|
|
68 |
except Exception as e:
|
69 |
print("⚠️ Could not load XLNet model → fallback to TF‑IDF\n", e)
|
70 |
|
71 |
-
#
|
72 |
-
#
|
73 |
-
#
|
74 |
-
embedding_cache = {}
|
75 |
-
|
76 |
def get_model_prediction(q, s, r):
|
77 |
if not xlnet_available:
|
78 |
raise ValueError("XLNet unavailable")
|
@@ -95,4 +69,4 @@ def get_similarity_score(q, s, r):
|
|
95 |
try:
|
96 |
return get_model_prediction(q, s, r) if xlnet_available else tfidf_similarity(s, r)
|
97 |
except Exception:
|
98 |
-
return fallback_similarity(s, r)
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from torch import nn
|
5 |
from transformers import XLNetModel, XLNetTokenizer
|
6 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
7 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
8 |
from huggingface_hub import hf_hub_download
|
|
|
9 |
|
10 |
+
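# Set Hugging Face cache directory. NOTE(review): huggingface_hub appears to resolve HF_HOME at import time, so this assignment likely needs to run before the transformers/huggingface_hub imports to take effect — confirm.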
|
11 |
+
os.environ["HF_HOME"] = "/tmp/huggingface"
|
12 |
+
|
13 |
+
# Download model weights from Hugging Face Hub
|
14 |
MODEL_PATH = hf_hub_download(
|
15 |
repo_id="yeswanthvarma/xlnet-evaluator-model",
|
16 |
filename="xlnet_answer_assessment_model.pt"
|
17 |
)
|
18 |
|
19 |
+
# Define your custom model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
class XLNetAnswerAssessmentModel(nn.Module):
|
21 |
def __init__(self):
|
22 |
super().__init__()
|
|
|
32 |
x = torch.relu(self.fc2(x))
|
33 |
return torch.sigmoid(self.out(x))
|
34 |
|
35 |
+
# Initialize model and tokenizer
|
36 |
+
xlnet_available = False
|
37 |
try:
|
38 |
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
|
39 |
model = XLNetAnswerAssessmentModel()
|
|
|
44 |
except Exception as e:
|
45 |
print("⚠️ Could not load XLNet model → fallback to TF‑IDF\n", e)
|
46 |
|
47 |
+
# -------------------------------
|
48 |
+
# Scoring logic
|
49 |
+
# -------------------------------
|
|
|
|
|
50 |
def get_model_prediction(q, s, r):
|
51 |
if not xlnet_available:
|
52 |
raise ValueError("XLNet unavailable")
|
|
|
69 |
try:
|
70 |
return get_model_prediction(q, s, r) if xlnet_available else tfidf_similarity(s, r)
|
71 |
except Exception:
|
72 |
+
return fallback_similarity(s, r)
|