SiddharthAK committed
Commit 5bf8193 (verified) · Parent: 5e87f41

Update app.py
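
Each Explorer-tab representation function now returns a tuple of two strings, (main_representation_str, info_str), instead of one combined Markdown string, and the "Sparse Representation" tab is rebuilt from a single gr.Interface into an explicit gr.Blocks layout: a two-column row (encoder radio, input textbox, and a vector-info panel on the left; the main representation on the right), wired up with .change() events and an initial demo.load() call.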

Files changed (1): app.py (+63 −38)
app.py CHANGED
@@ -79,10 +79,10 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
 
 
 # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
-# These functions take single text input for the Explorer tab
+# These functions now return a tuple: (main_representation_str, info_str)
 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -96,7 +96,7 @@ def get_splade_cocondenser_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here as it's a single input
     else:
-        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
@@ -120,16 +120,16 @@ def get_splade_cocondenser_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-    formatted_output += "\n--- Sparse Vector Info ---\n"
-    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
-    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
+    info_output = f"--- Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
@@ -143,7 +143,7 @@ def get_splade_lexical_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here
     else:
-        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""
 
     # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
@@ -175,16 +175,16 @@ def get_splade_lexical_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-    formatted_output += "\n--- Raw Sparse Vector Info ---\n"
-    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
-    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
+    info_output = f"--- Raw Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None: # No longer need model_splade_doc to be loaded for 'logits'
-        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
@@ -220,11 +220,11 @@ def get_splade_doc_representation(text):
             break
         formatted_output += f"- **{term}**\n"
 
-    formatted_output += "\n--- Raw Binary Bag-of-Words Vector Info ---\n" # Changed title
-    formatted_output += f"Total activated terms: {len(indices)}\n"
-    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
+    info_output = f"--- Raw Binary Bag-of-Words Vector Info ---\n" # Changed title
+    info_output += f"Total activated terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 # --- Unified Prediction Function for the Explorer Tab ---
@@ -236,7 +236,7 @@ def predict_representation_explorer(model_choice, text):
     elif model_choice == "Binary Bag-of-Words": # Changed name
         return get_splade_doc_representation(text)
     else:
-        return "Please select a model."
+        return "Please select a model.", "" # Return two empty strings for consistency
 
 # --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
 # These functions remain unchanged from the previous iteration, as they return the raw tensors.
@@ -339,10 +339,10 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
         else:
             formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-    formatted_output += f"\nTotal non-zero terms: {len(indices)}\n"
-    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
+    info_output = f"\nTotal non-zero terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output # Now returns two strings
 
 
 # --- NEW/MODIFIED: Helper to get the correct vector function, tokenizer, and binary flag ---
@@ -376,11 +376,16 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
 
     # Format representations
+    # These functions now return two strings (main_output, info_output)
+    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
+    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
+
+
     query_rep_str = f"Query Representation ({query_model_name_display}):\n"
-    query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
+    query_rep_str += query_main_rep_str + "\n" + query_info_str
 
     doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
-    doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
+    doc_rep_str += doc_main_rep_str + "\n" + doc_info_str
 
     # Combine output
     full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
@@ -397,30 +402,50 @@ with gr.Blocks(title="SPLADE Demos") as demo:
 
     with gr.Tabs():
         with gr.TabItem("Sparse Representation"):
-            gr.Markdown("### Produce a Sparse Representation of of an Input Text")
-            gr.Interface(
-                fn=predict_representation_explorer,
-                inputs=[
-                    gr.Radio(
+            gr.Markdown("### Produce a Sparse Representation of an Input Text")
+            with gr.Row():
+                with gr.Column(scale=1): # Left column for inputs and info
+                    model_radio = gr.Radio(
                         [
                             "MLM encoder (SPLADE-cocondenser-distil)",
                             "MLP encoder (SPLADE-v3-lexical)",
-                            "Binary Bag-of-Words" # Changed name here
+                            "Binary Bag-of-Words"
                         ],
                         label="Choose Sparse Encoder",
                         value="MLM encoder (SPLADE-cocondenser-distil)"
-                    ),
-                    gr.Textbox(
+                    )
+                    input_text = gr.Textbox(
                         lines=5,
                         label="Enter your query or document text here:",
                         placeholder="e.g., Why is Padua the nicest city in Italy?"
                     )
-                ],
-                outputs=gr.Markdown(),
-                allow_flagging="never",
-                # live=True # Setting live=True might be slow for complex models on every keystroke
+                    # New Markdown component for the info output
+                    info_output_display = gr.Markdown(
+                        value="",
+                        label="Vector Information",
+                        elem_id="info_output_display" # Add an ID for potential CSS if needed
+                    )
+                with gr.Column(scale=2): # Right column for the main representation output
+                    main_representation_output = gr.Markdown()
+
+            # Connect the interface elements
+            model_radio.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
             )
-
+            input_text.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
+            )
+
+            # Initial call to populate on load (optional, but good for demo)
+            demo.load(
+                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
+                outputs=[main_representation_output, info_output_display]
+            )
+
        with gr.TabItem("Compare Encoders"): # NEW TAB
            gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
            gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
@@ -429,7 +454,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
             model_choices = [
                 "MLM encoder (SPLADE-cocondenser-distil)",
                 "MLP encoder (SPLADE-v3-lexical)",
-                "Binary Bag-of-Words" # Changed name here
+                "Binary Bag-of-Words"
             ]
 
             gr.Interface(
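
For context, here is a minimal self-contained sketch of the two-string convention these functions now share. The helper name `sketch_representation` is hypothetical, and the pooling step (log-saturated ReLU over the MLM logits, max-pooled over the sequence) is the standard SPLADE formulation assumed here, since the diff only shows its closing lines; the real app uses its module-level `tokenizer_splade` / `model_splade` globals instead of parameters.

```python
import torch

def sketch_representation(text, tokenizer, model):
    # Hypothetical helper mirroring get_splade_cocondenser_representation:
    # encode `text`, pool a sparse vector, return (representation, info).
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model(**inputs)  # assumes a masked-LM head exposing .logits

    # Assumed standard SPLADE pooling: log(1 + relu(logits)), masked for
    # padding, then max over the sequence dimension.
    splade_vector = torch.max(
        torch.log1p(torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):  # a single non-zero entry decays to an int
        indices = [indices]

    weights = {tokenizer.decode([i]): splade_vector[i].item() for i in indices}
    formatted_output = "\n".join(
        f"- **{term}**: {weight:.4f}"
        for term, weight in sorted(weights.items(), key=lambda kv: -kv[1])
    )
    info_output = (
        f"Total non-zero terms in vector: {len(indices)}\n"
        f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
    )
    return formatted_output, info_output  # the (main, info) tuple the UI unpacks
```

The sparsity figure is just the fraction of the vocabulary left at zero: with a BERT-style vocabulary of 30,522 entries, a vector with 120 non-zero terms reports 1 − 120/30522 ≈ 99.61%.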