yeswanthvarma commited on
Commit
c428c79
·
verified ·
1 Parent(s): 1af9d91

Update utils/xlnet_model.py

Browse files
Files changed (1) hide show
  1. utils/xlnet_model.py +14 -40
utils/xlnet_model.py CHANGED
@@ -1,48 +1,22 @@
1
- import requests, tqdm, torch, numpy as np
2
- from torch import nn
3
  import os
4
- os.environ["HF_HOME"] = "/tmp/huggingface"
 
 
5
  from transformers import XLNetModel, XLNetTokenizer
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
-
9
- # ------------------------------------------------------------------
10
  from huggingface_hub import hf_hub_download
11
- import torch
12
 
13
- # Downloads the file automatically from your model repo
 
 
 
14
  MODEL_PATH = hf_hub_download(
15
  repo_id="yeswanthvarma/xlnet-evaluator-model",
16
  filename="xlnet_answer_assessment_model.pt"
17
  )
18
 
19
- # Load the model
20
- # Load tokenizer and model architecture
21
- tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
22
- model = XLNetModel.from_pretrained("xlnet-base-cased")
23
-
24
- # Then load your custom weights
25
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
26
- model.eval()
27
-
28
- def download_model_if_needed():
29
- if os.path.exists(MODEL_PATH):
30
- return
31
- print("▶️ Downloading XLNet weights from Hugging Face …")
32
- with requests.get(HF_URL, stream=True) as r:
33
- r.raise_for_status()
34
- total = int(r.headers.get("content-length", 0))
35
- with open(MODEL_PATH, "wb") as f, tqdm.tqdm(total=total, unit="B", unit_scale=True) as bar:
36
- for chunk in r.iter_content(chunk_size=8192):
37
- f.write(chunk)
38
- bar.update(len(chunk))
39
- print("✅ Download complete.")
40
-
41
- download_model_if_needed()
42
- # ------------------------------------------------------------------
43
-
44
- xlnet_available = False # will flip to True if load succeeds
45
-
46
  class XLNetAnswerAssessmentModel(nn.Module):
47
  def __init__(self):
48
  super().__init__()
@@ -58,6 +32,8 @@ class XLNetAnswerAssessmentModel(nn.Module):
58
  x = torch.relu(self.fc2(x))
59
  return torch.sigmoid(self.out(x))
60
 
 
 
61
  try:
62
  tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
63
  model = XLNetAnswerAssessmentModel()
@@ -68,11 +44,9 @@ try:
68
  except Exception as e:
69
  print("⚠️ Could not load XLNet model → fallback to TF‑IDF\n", e)
70
 
71
- # ------------------------------------------------------------------
72
- # scoring helpers (unchanged)
73
- # ------------------------------------------------------------------
74
- embedding_cache = {}
75
-
76
  def get_model_prediction(q, s, r):
77
  if not xlnet_available:
78
  raise ValueError("XLNet unavailable")
@@ -95,4 +69,4 @@ def get_similarity_score(q, s, r):
95
  try:
96
  return get_model_prediction(q, s, r) if xlnet_available else tfidf_similarity(s, r)
97
  except Exception:
98
- return fallback_similarity(s, r)
 
 
 
1
  import os
2
+ import torch
3
+ import numpy as np
4
+ from torch import nn
5
  from transformers import XLNetModel, XLNetTokenizer
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.metrics.pairwise import cosine_similarity
 
 
8
  from huggingface_hub import hf_hub_download
 
9
 
10
+ # Set Hugging Face cache directory
11
+ os.environ["HF_HOME"] = "/tmp/huggingface"
12
+
13
+ # Download model weights from Hugging Face Hub
14
  MODEL_PATH = hf_hub_download(
15
  repo_id="yeswanthvarma/xlnet-evaluator-model",
16
  filename="xlnet_answer_assessment_model.pt"
17
  )
18
 
19
+ # Define your custom model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class XLNetAnswerAssessmentModel(nn.Module):
21
  def __init__(self):
22
  super().__init__()
 
32
  x = torch.relu(self.fc2(x))
33
  return torch.sigmoid(self.out(x))
34
 
35
+ # Initialize model and tokenizer
36
+ xlnet_available = False
37
  try:
38
  tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
39
  model = XLNetAnswerAssessmentModel()
 
44
  except Exception as e:
45
  print("⚠️ Could not load XLNet model → fallback to TF‑IDF\n", e)
46
 
47
+ # -------------------------------
48
+ # Scoring logic
49
+ # -------------------------------
 
 
50
  def get_model_prediction(q, s, r):
51
  if not xlnet_available:
52
  raise ValueError("XLNet unavailable")
 
69
  try:
70
  return get_model_prediction(q, s, r) if xlnet_available else tfidf_similarity(s, r)
71
  except Exception:
72
+ return fallback_similarity(s, r)