import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None

# Load SPLADE model (the cocondenser-selfdistil checkpoint)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()  # Set to evaluation mode for inference
    print("SPLADE model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load UNICOIL model for binary sparse encoding
try:
    unicoil_model_name = "castorini/unicoil-msmarco-passage"
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    # Loaded as a masked LM so the full vocabulary logits are exposed;
    # the scoring function below reads each input token's own logit.
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # Set to evaluation mode for inference
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading UNICOIL model: {e}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Core Representation Functions ---
def get_splade_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        # SPLADE pooling: w_j = max_i log(1 + ReLU(logit_{i,j})), with padding
        # positions zeroed out by the attention mask before the max.
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices]  # A single non-zero entry squeezes down to a scalar
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    # Decode vocabulary ids to strings, dropping special tokens
    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "SPLADE Representation (All Non-Zero Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for term, weight in sorted_representation:
            formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
    return formatted_output
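
# Quick smoke test for the SPLADE path, using the demo's own placeholder
# query (uncomment to run from a plain Python session, without the UI):
# print(get_splade_representation("Why is Padua the nicest city in Italy?"))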
def get_unicoil_binary_representation(text):
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    with torch.no_grad():
        output = model_unicoil(**inputs)

    if not hasattr(output, "logits"):
        return "UNICOIL model output structure not as expected. 'logits' not found."

    logits = output.logits.squeeze(0)      # [seq_len, vocab_size]
    token_ids = input_ids.squeeze(0)       # [seq_len]
    mask = attention_mask.squeeze(0)       # [seq_len]

    # Softplus keeps scores non-negative; this is the numerically stable
    # form of log(1 + exp(x)), which overflows for large logits.
    transformed_scores = torch.nn.functional.softplus(logits)
    # Take each input token's score for its own vocabulary id
    positions = torch.arange(len(token_ids), device=token_ids.device)
    token_scores = transformed_scores[positions, token_ids]
    token_scores = token_scores * mask  # Mask out padding

    # Binarize: keep tokens whose score exceeds 0.5 (threshold can be tuned)
    binary_mask = token_scores > 0.5
    activated_token_ids = token_ids[binary_mask].cpu().tolist()

    # Map token ids to strings, dropping special tokens
    binary_terms = {}
    for token_id in activated_token_ids:
        decoded_token = tokenizer_unicoil.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            binary_terms[decoded_token] = 1

    sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])

    formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
    if not sorted_binary_terms:
        formatted_output += "No significant terms activated for this input.\n"
    else:
        for i, (term, _) in enumerate(sorted_binary_terms):
            if i >= 50:
                formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
                break
            formatted_output += f"- **{term}**\n"

    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
    formatted_output += f"Total activated terms: {len(sorted_binary_terms)}\n"
    formatted_output += f"Sparsity: {1 - (len(sorted_binary_terms) / tokenizer_unicoil.vocab_size):.2%}\n"
    return formatted_output
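
# Matching smoke test for the UNICOIL path (uncomment to run without the UI):
# print(get_unicoil_binary_representation("Why is Padua the nicest city in Italy?"))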

# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
    if model_choice == "SPLADE":
        return get_splade_representation(text)
    elif model_choice == "UNICOIL (Binary Sparse)":
        return get_unicoil_binary_representation(text)
    else:
        return "Please select a model."

# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE", "UNICOIL (Binary Sparse)"],
            label="Choose Representation Model",
            value="SPLADE"  # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?"
        )
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    allow_flagging="never"
)

# Launch the Gradio app
demo.launch()
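# For local runs, demo.launch(share=True) creates a temporary public URL;
# on Spaces the plain launch() above is sufficient.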