SiddharthAK committed on
Commit ab88097 · verified · 1 Parent(s): c1ea33e

removed UNICOIL, added SPLADE-v3-lexical

Files changed (1): app.py (+53, -66)
app.py CHANGED
@@ -5,39 +5,36 @@ import torch
 # --- Model Loading ---
 tokenizer_splade = None
 model_splade = None
-tokenizer_unicoil = None
-model_unicoil = None
+tokenizer_splade_lexical = None
+model_splade_lexical = None
 
-# Load SPLADE v3 model
+# Load SPLADE v3 model (original)
 try:
     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade.eval()  # Set to evaluation mode for inference
-    print("SPLADE v3 model loaded successfully!")
+    print("SPLADE v3 (cocondenser) model loaded successfully!")
 except Exception as e:
-    print(f"Error loading SPLADE model: {e}")
+    print(f"Error loading SPLADE (cocondenser) model: {e}")
     print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
 
-# Load UNICOIL model for binary sparse encoding
-# Load UNICOIL model for binary sparse encoding
+# Load SPLADE v3 Lexical model
 try:
-    unicoil_model_name = "castorini/unicoil-msmarco-passage"
-    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
-    # --- FIX IS HERE ---
-    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
-    # -------------------
-    model_unicoil.eval()  # Set to evaluation mode for inference
-    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
+    splade_lexical_model_name = "naver/splade-v3-lexical"
+    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
+    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
+    model_splade_lexical.eval()  # Set to evaluation mode for inference
+    print(f"SPLADE v3 Lexical model '{splade_lexical_model_name}' loaded successfully!")
 except Exception as e:
-    print(f"Error loading UNICOIL model: {e}")
-    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
+    print(f"Error loading SPLADE v3 Lexical model: {e}")
+    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
 # --- Core Representation Functions ---
 
 def get_splade_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE model is not loaded. Please check the console for loading errors."
+        return "SPLADE (cocondenser) model is not loaded. Please check the console for loading errors."
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -51,7 +48,7 @@ def get_splade_representation(text):
             dim=1
         )[0].squeeze()
     else:
-        return "Model output structure not as expected for SPLADE. 'logits' not found."
+        return "Model output structure not as expected for SPLADE (cocondenser). 'logits' not found."
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
@@ -68,7 +65,7 @@ def get_splade_representation(text):
 
     sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "SPLADE Representation (All Non-Zero Terms):\n"
+    formatted_output = "SPLADE (cocondenser) Representation (All Non-Zero Terms):\n"
     if not sorted_representation:
         formatted_output += "No significant terms found for this input.\n"
     else:
@@ -82,69 +79,59 @@ def get_splade_representation(text):
     return formatted_output
 
 
-
-
-def get_unicoil_binary_representation(text):
-    if tokenizer_unicoil is None or model_unicoil is None:
-        return "UNICOIL model is not loaded. Please check the console for loading errors."
-
-    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}
+def get_splade_lexical_representation(text):
+    if tokenizer_splade_lexical is None or model_splade_lexical is None:
+        return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."
+
+    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
 
     with torch.no_grad():
-        output = model_unicoil(**inputs)
+        output = model_splade_lexical(**inputs)
 
-    if not hasattr(output, "logits"):
-        return "UNICOIL model output structure not as expected. 'logits' not found."
-
-    logits = output.logits.squeeze(0)    # [seq_len, vocab_size]
-    token_ids = input_ids.squeeze(0)     # [seq_len]
-    mask = attention_mask.squeeze(0)     # [seq_len]
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+    else:
+        return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
 
-    transformed_scores = torch.log(1 + torch.exp(logits))  # softplus
-    token_scores = transformed_scores[range(len(token_ids)), token_ids]  # only scores for input tokens
-    token_scores = token_scores * mask   # mask out padding
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):
+        indices = [indices]
 
-    # Binarize: threshold scores > 0.5 (tune as needed)
-    binary_mask = (token_scores > 0.5)
-    activated_token_ids = token_ids[binary_mask].cpu().tolist()
+    values = splade_vector[indices].cpu().tolist()
+    token_weights = dict(zip(indices, values))
 
-    # Map token ids to strings
-    binary_terms = {}
-    for token_id in activated_token_ids:
-        decoded_token = tokenizer_unicoil.decode([token_id])
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer_splade_lexical.decode([token_id])
         if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
-            binary_terms[decoded_token] = 1
+            meaningful_tokens[decoded_token] = weight
 
-    sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])
+    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
 
-    formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
-    if not sorted_binary_terms:
-        formatted_output += "No significant terms activated for this input.\n"
+    formatted_output = "SPLADE v3 Lexical Representation (All Non-Zero Terms):\n"
+    if not sorted_representation:
+        formatted_output += "No significant terms found for this input.\n"
     else:
-        for i, (term, _) in enumerate(sorted_binary_terms):
-            if i >= 50:
-                formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"
-                break
-            formatted_output += f"- **{term}**\n"
+        for term, weight in sorted_representation:
+            formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
-    formatted_output += f"Total activated terms: {len(sorted_binary_terms)}\n"
-    formatted_output += f"Sparsity: {1 - (len(sorted_binary_terms) / tokenizer_unicoil.vocab_size):.2%}\n"
+    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
+    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
 
     return formatted_output
 
 
-
-
 # --- Unified Prediction Function for Gradio ---
 def predict_representation(model_choice, text):
-    if model_choice == "SPLADE":
+    if model_choice == "SPLADE (cocondenser)":
         return get_splade_representation(text)
-    elif model_choice == "UNICOIL (Binary Sparse)":
-        return get_unicoil_binary_representation(text)
+    elif model_choice == "SPLADE-v3-Lexical":
+        return get_splade_lexical_representation(text)
     else:
         return "Please select a model."
 
@@ -153,9 +140,9 @@ demo = gr.Interface(
     fn=predict_representation,
     inputs=[
         gr.Radio(
-            ["SPLADE", "UNICOIL"],  # Added UNICOIL option
+            ["SPLADE (cocondenser)", "SPLADE-v3-Lexical"],  # Updated options
             label="Choose Representation Model",
-            value="SPLADE"  # Default selection
+            value="SPLADE (cocondenser)"  # Default selection
         ),
         gr.Textbox(
             lines=5,
@@ -165,7 +152,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Markdown(),
     title="🌌 Sparse and Binary Sparse Representation Generator",
-    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
+    description="Enter any text to see its SPLADE sparse vector or SPLADE-v3-Lexical representation.",
     allow_flagging="never"
 )
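For completeness, a minimal standalone sketch (not part of the commit) for smoke-testing the newly added model outside the Gradio UI. The model name and pooling mirror the diff; the query string and `k=10` are arbitrary illustration choices:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_name = "naver/splade-v3-lexical"  # the model added in this commit
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model.eval()

inputs = tokenizer("sparse retrieval with learned term weights",
                   return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits

# Same pooling as get_splade_lexical_representation in the diff.
vec = torch.max(
    torch.log(1 + torch.relu(logits)) * inputs["attention_mask"].unsqueeze(-1),
    dim=1,
)[0].squeeze()

# Print the ten highest-weighted vocabulary terms.
top = torch.topk(vec, k=10)
for weight, token_id in zip(top.values.tolist(), top.indices.tolist()):
    print(f"{tokenizer.decode([token_id]):<15} {weight:.4f}")
```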