# answer-evaluation-app / utils/xlnet_model.py
# Author: yeswanthvarma
# Commit: 7963a78 (verified) — "Update utils/xlnet_model.py"
import os
import torch
from torch import nn
from transformers import XLNetModel, XLNetTokenizer
from huggingface_hub import hf_hub_download
# Hugging Face cache location. NOTE: transformers / huggingface_hub are imported
# above, and huggingface_hub computes its default cache path when it is first
# imported, so assigning HF_HOME here does not by itself redirect downloads.
# The download below therefore passes the intended cache directory explicitly.
os.environ["HF_HOME"] = "/tmp/huggingface"

# Download the trained model weights from the Hugging Face Hub.
# A failure here (no network, missing repo, ...) must not crash the import:
# MODEL_PATH is left as None so the model-loading code's try/except takes its
# fallback path instead of the whole module failing to import.
try:
    MODEL_PATH = hf_hub_download(
        repo_id="yeswanthvarma/xlnet-evaluator-model",
        filename="xlnet_answer_assessment_model.pt",
        cache_dir=os.path.join(os.environ["HF_HOME"], "hub"),
    )
except Exception as e:
    MODEL_PATH = None
    print("⚠️ Could not download XLNet weights → fallback will be used\n", e)
# Trained model architecture: XLNet encoder + small MLP regression head.
class XLNetAnswerAssessmentModel(nn.Module):
    """XLNet encoder followed by a 768 -> 256 -> 64 -> 1 MLP with a sigmoid output.

    The submodule attribute names (``xlnet``, ``fc1``, ``fc2``, ``output``)
    must not change: they are the keys matched by ``load_state_dict`` when
    the fine-tuned checkpoint is loaded into this module.
    """

    def __init__(self):
        super().__init__()
        # HF_HOME is assigned after transformers is imported, so the library's
        # default cache ignores it; pass the intended cache directory explicitly.
        hf_home = os.environ.get("HF_HOME")
        cache_dir = os.path.join(hf_home, "hub") if hf_home else None
        # Pretrained backbone weights are fetched here; the fine-tuned
        # state_dict loaded afterwards replaces them.
        self.xlnet = XLNetModel.from_pretrained("xlnet-base-cased", cache_dir=cache_dir)
        self.fc1 = nn.Linear(768, 256)  # 768 = hidden size of xlnet-base-cased
        self.fc2 = nn.Linear(256, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, input_ids, attention_mask=None):
        """Score a batch of tokenized inputs.

        Args:
            input_ids: token id tensor, presumably (batch, seq_len) — TODO confirm.
            attention_mask: optional mask passed straight through to XLNet.

        Returns:
            Sigmoid scores in [0, 1], one per batch row (presumably shape (batch, 1)).
        """
        hidden = self.xlnet(input_ids, attention_mask).last_hidden_state
        # NOTE(review): plain mean over the sequence axis, so padding positions
        # are included for padded batches. Kept as-is on the assumption that
        # this matches how the checkpoint was trained — confirm before changing.
        pooled = hidden.mean(dim=1)
        x = torch.relu(self.fc1(pooled))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.output(x))  # score in range [0, 1]
# Load tokenizer and model once at import time. Any failure leaves
# xlnet_available False (and tokenizer/model None), so get_model_prediction
# raises and callers can fall back to word-overlap similarity.
tokenizer = None
model = None
xlnet_available = False
try:
    if MODEL_PATH is None:
        raise FileNotFoundError("XLNet weights were not downloaded")
    # HF_HOME was set after transformers was imported; pass the cache dir explicitly.
    _hf_home = os.environ.get("HF_HOME")
    _cache_dir = os.path.join(_hf_home, "hub") if _hf_home else None
    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased", cache_dir=_cache_dir)
    model = XLNetAnswerAssessmentModel()
    # The checkpoint is a plain state_dict, so weights_only=True is sufficient
    # and avoids unpickling arbitrary objects from a downloaded file.
    model.load_state_dict(
        torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
    )
    model.eval()  # inference mode (disables dropout)
    xlnet_available = True
    print("✅ Custom XLNet model loaded.")
except Exception as e:
    # Deliberate best-effort boundary: do not keep half-initialised objects.
    tokenizer = None
    model = None
    print("⚠️ Could not load XLNet model → fallback will be used\n", e)
# -------------------------------
# Main prediction function
# -------------------------------
def get_model_prediction(q, s, r):
    """Score one input triple with the fine-tuned XLNet model.

    Args:
        q, s, r: text fields, presumably question / student answer /
            reference answer — TODO confirm against the caller.

    Returns:
        int: the model's [0, 1] score scaled to 0-100 and rounded.

    Raises:
        RuntimeError: if the XLNet model was not loaded at import time.
    """
    if xlnet_available:
        # Join the fields with the same " [SEP] " layout used during training.
        joined_text = " [SEP] ".join(f"{part}" for part in (q, s, r))
        encoding = tokenizer(
            joined_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            probability = model(
                attention_mask=encoding["attention_mask"],
                input_ids=encoding["input_ids"],
            )
        # Map the [0, 1] sigmoid output onto a 0-100 scale.
        return round(probability.squeeze().item() * 100)
    raise RuntimeError("XLNet model not available")
# Fallback similarity: word-overlap score used when the model is unavailable.
def fallback_similarity(t1, t2):
    """Return the Jaccard overlap of the two texts' word sets as an int 0-100.

    Words are lowercased and split on whitespace. Returns 0 if either text
    has no words.
    """
    words_a = set(t1.lower().split())
    words_b = set(t2.lower().split())
    if not (words_a and words_b):
        return 0
    overlap = len(words_a.intersection(words_b)) / len(words_a.union(words_b))
    return round(overlap * 100)
# Final score API (use in app.py)
def get_similarity_score(q, s, r):
    """Return a 0-100 score, preferring the XLNet model.

    If model prediction raises for any reason (including the model not being
    loaded), logs the error and falls back to word-overlap similarity between
    ``s`` and ``r``; ``q`` is not used by the fallback.
    """
    try:
        return get_model_prediction(q, s, r)
    except Exception as e:
        # Deliberate best-effort boundary: any model failure degrades to the fallback.
        print("❌ XLNet failed, using fallback:", e)
        return fallback_similarity(s, r)