Update app.py

app.py CHANGED
@@ -79,10 +79,10 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
 
 
 # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
-# These functions
+# These functions now return a tuple: (main_representation_str, info_str)
 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
-        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
+        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
@@ -96,7 +96,7 @@ def get_splade_cocondenser_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here as it's a single input
     else:
-        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""
 
     indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
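The expression guarded by this hunk, `torch.max(torch.log(1 + torch.relu(output.logits)) * attention_mask, dim=1)`, is the standard SPLADE pooling: ReLU keeps positive term activations, `log(1 + x)` saturates dominant terms, and the max over the sequence dimension keeps each vocabulary term's strongest activation. A minimal sketch of that computation in isolation (the helper name `splade_activation` is ours, not from app.py):

```python
import torch

def splade_activation(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Collapse MLM logits of shape (batch, seq_len, vocab_size) into one
    sparse vector per input via SPLADE's log-saturated max pooling."""
    # Zero out padding positions so they cannot win the max.
    weights = torch.log1p(torch.relu(logits)) * attention_mask.unsqueeze(-1)
    return torch.max(weights, dim=1).values  # shape: (batch, vocab_size)

# Toy check: 1 sequence, 2 tokens, 5-term vocabulary.
logits = torch.randn(1, 2, 5)
mask = torch.ones(1, 2)
print(splade_activation(logits, mask).shape)  # torch.Size([1, 5])
```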
@@ -120,16 +120,16 @@ def get_splade_cocondenser_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
-
+    info_output = f"--- Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_lexical_representation(text):
     if tokenizer_splade_lexical is None or model_splade_lexical is None:
-        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
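The new `info_output` lines derive vector statistics from `indices`, the non-zero dimensions found via `torch.nonzero` above; sparsity is simply the fraction of the vocabulary left untouched. A standalone sketch of the same arithmetic:

```python
import torch

def sparse_vector_stats(splade_vector: torch.Tensor, vocab_size: int) -> str:
    """Mirror of the diff's info_output: active-term count and sparsity."""
    nonzero = int(torch.count_nonzero(splade_vector).item())
    return (f"Total non-zero terms in vector: {nonzero}\n"
            f"Sparsity: {1 - (nonzero / vocab_size):.2%}")

vec = torch.zeros(30_522)   # BERT-sized vocabulary
vec[:120] = 1.0             # pretend 120 terms are active
print(sparse_vector_stats(vec, 30_522))  # Sparsity: 99.61%
```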
@@ -143,7 +143,7 @@ def get_splade_lexical_representation(text):
             dim=1
         )[0].squeeze() # Squeeze is fine here
     else:
-        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found."
+        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""
 
     # Always apply lexical mask for this model's specific behavior
     vocab_size = tokenizer_splade_lexical.vocab_size
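The "lexical mask" applied here comes from the `create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer)` helper named in the first hunk header; its body is not visible in this diff. A plausible sketch of what such a mask does, with the signature and details being our assumption rather than the app's actual code:

```python
import torch

def create_lexical_bow_mask_sketch(input_ids: torch.Tensor, vocab_size: int,
                                   special_ids: set) -> torch.Tensor:
    """Binary (1, vocab_size) mask: 1 only for terms present in the input.

    Multiplying a SPLADE vector by this mask discards every expansion term,
    which is the point of the 'lexical' variant: weights survive only on
    words the text actually contains."""
    mask = torch.zeros(1, vocab_size)
    for tid in input_ids.flatten().tolist():
        if tid not in special_ids:  # skip [CLS], [SEP], [PAD], ...
            mask[0, tid] = 1.0
    return mask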
@@ -175,16 +175,16 @@ def get_splade_lexical_representation(text):
     for term, weight in sorted_representation:
         formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
-
+    info_output = f"--- Raw Sparse Vector Info ---\n"
+    info_output += f"Total non-zero terms in vector: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_lexical.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None: # No longer need model_splade_doc to be loaded for 'logits'
-        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors."
+        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""
 
     inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
     inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation
@@ -220,11 +220,11 @@ def get_splade_doc_representation(text):
             break
         formatted_output += f"- **{term}**\n"
 
-
-
-
+    info_output = f"--- Raw Binary Bag-of-Words Vector Info ---\n" # Changed title
+    info_output += f"Total activated terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade_doc.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output
 
 
 # --- Unified Prediction Function for the Explorer Tab ---
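For the Binary Bag-of-Words path the model is bypassed entirely: every non-special input token becomes a weight-1 dimension, so the info block reports "activated terms" rather than weights. A self-contained sketch of that representation using BERT's tokenizer, which SPLADE models share; the model id below is a stand-in, not necessarily the checkpoint the app loads:

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # 30,522-term vocab

text = "why is padua the nicest city in italy"
ids = tokenizer(text, add_special_tokens=False)["input_ids"]

vec = torch.zeros(tokenizer.vocab_size)
vec[torch.tensor(ids, dtype=torch.long)] = 1.0  # binary: presence only, no weights

indices = torch.nonzero(vec).squeeze(-1).tolist()
print(f"Total activated terms: {len(indices)}")
print(tokenizer.convert_ids_to_tokens(indices))  # the bare term list the tab renders
```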
@@ -236,7 +236,7 @@ def predict_representation_explorer(model_choice, text):
     elif model_choice == "Binary Bag-of-Words": # Changed name
         return get_splade_doc_representation(text)
     else:
-        return "Please select a model."
+        return "Please select a model.", "" # Return two empty strings for consistency
 
 # --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
 # These functions remain unchanged from the previous iteration, as they return the raw tensors.
@@ -339,10 +339,10 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
         else:
             formatted_output += f"- **{term}**: {weight:.4f}\n"
 
-
-
+    info_output = f"\nTotal non-zero terms: {len(indices)}\n"
+    info_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
 
-    return formatted_output
+    return formatted_output, info_output # Now returns two strings
 
 
 # --- NEW/MODIFIED: Helper to get the correct vector function, tokenizer, and binary flag ---
@@ -376,11 +376,16 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
 
     # Format representations
+    # These functions now return two strings (main_output, info_output)
+    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
+    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
+
+
     query_rep_str = f"Query Representation ({query_model_name_display}):\n"
-    query_rep_str +=
+    query_rep_str += query_main_rep_str + "\n" + query_info_str
 
     doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
-    doc_rep_str +=
+    doc_rep_str += doc_main_rep_str + "\n" + doc_info_str
 
     # Combine output
     full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
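Because every encoder in the demo emits a vector over the same vocabulary space, cross-model relevance reduces to the single `torch.dot` call above, and only dimensions active on both sides contribute to the score. A toy illustration (the ids and weights are made up):

```python
import torch

VOCAB = 30_522
query_vec, doc_vec = torch.zeros(VOCAB), torch.zeros(VOCAB)
query_vec[[2103, 3000]] = torch.tensor([1.2, 0.7])  # hypothetical activations
doc_vec[[2103, 5000]] = torch.tensor([0.9, 2.0])

# Same call as the diff: move both to CPU, take the scalar product.
score = float(torch.dot(query_vec.cpu(), doc_vec.cpu()).item())
print(f"Dot Product Score: {score:.6f}")  # 1.080000, only shared id 2103 contributes
```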
@@ -397,30 +402,50 @@ with gr.Blocks(title="SPLADE Demos") as demo:
 
     with gr.Tabs():
         with gr.TabItem("Sparse Representation"):
-            gr.Markdown("### Produce a Sparse Representation of
-            gr.
-
-
-            gr.Radio(
+            gr.Markdown("### Produce a Sparse Representation of an Input Text")
+            with gr.Row():
+                with gr.Column(scale=1): # Left column for inputs and info
+                    model_radio = gr.Radio(
                         [
                             "MLM encoder (SPLADE-cocondenser-distil)",
                             "MLP encoder (SPLADE-v3-lexical)",
-                            "Binary Bag-of-Words"
+                            "Binary Bag-of-Words"
                         ],
                         label="Choose Sparse Encoder",
                         value="MLM encoder (SPLADE-cocondenser-distil)"
-            )
-            gr.Textbox(
+                    )
+                    input_text = gr.Textbox(
                         lines=5,
                         label="Enter your query or document text here:",
                         placeholder="e.g., Why is Padua the nicest city in Italy?"
                     )
-
-
-
-
+                    # New Markdown component for the info output
+                    info_output_display = gr.Markdown(
+                        value="",
+                        label="Vector Information",
+                        elem_id="info_output_display" # Add an ID for potential CSS if needed
+                    )
+                with gr.Column(scale=2): # Right column for the main representation output
+                    main_representation_output = gr.Markdown()
+
+            # Connect the interface elements
+            model_radio.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
             )
-
+            input_text.change(
+                fn=predict_representation_explorer,
+                inputs=[model_radio, input_text],
+                outputs=[main_representation_output, info_output_display]
+            )
+
+            # Initial call to populate on load (optional, but good for demo)
+            demo.load(
+                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
+                outputs=[main_representation_output, info_output_display]
+            )
+
         with gr.TabItem("Compare Encoders"): # NEW TAB
             gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
             gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
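This hunk replaces the old auto-generated layout with explicit Blocks wiring: one handler returning a `(main, info)` tuple is bound to both the radio's and the textbox's `.change` events, and the tuple is mapped positionally onto the two Markdown outputs. A minimal runnable sketch of that pattern (component and function names here are generic, not the app's):

```python
import gradio as gr

def describe(choice: str, text: str) -> tuple[str, str]:
    # Two outputs are wired below, so the handler returns a 2-tuple.
    return f"**{choice}** applied to: {text}", f"_Info: {len(text)} characters_"

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            choice = gr.Radio(["A", "B"], label="Encoder", value="A")
            text = gr.Textbox(lines=2, label="Text")
            info = gr.Markdown()
        with gr.Column(scale=2):
            main = gr.Markdown()
    # Both triggers share one handler; outputs map positionally to the tuple.
    for comp in (choice, text):
        comp.change(fn=describe, inputs=[choice, text], outputs=[main, info])

demo.launch()
```

One design note on the diff's `demo.load(...)` call: the lambda reads `model_radio.value` and `input_text.value`, which in Gradio hold the components' construction-time defaults, so this renders the initial selection once when the page opens rather than reading live user input.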
@@ -429,7 +454,7 @@ with gr.Blocks(title="SPLADE Demos") as demo:
             model_choices = [
                 "MLM encoder (SPLADE-cocondenser-distil)",
                 "MLP encoder (SPLADE-v3-lexical)",
-                "Binary Bag-of-Words"
+                "Binary Bag-of-Words"
             ]
 
             gr.Interface(