Update app.py
app.py CHANGED
@@ -2,12 +2,10 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import torch
 import numpy as np
-from tqdm.auto import tqdm
-import os
-import ir_datasets
-import random # Added for random selection
+from tqdm.auto import tqdm # Still useful for model loading progress if desired, but not strictly necessary for this simplified version
+import os # Still useful for general purpose, but not explicitly used in this simplified version
 
-# --- Model Loading
+# --- Model Loading ---
 tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
@@ -48,44 +46,7 @@ except Exception as e:
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
-# ---
-document_representations = {} # Stores {doc_id: sparse_vector}
-document_texts = {} # Stores {doc_id: doc_text}
-queries_texts = {} # Stores {query_id: query_text}
-qrels_data = {} # Stores {query_id: [{doc_id: str, relevance: int}, ...]}
-initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
-
-
-# --- Load Cranfield Corpus, Queries, and Qrels using ir_datasets ---
-def load_cranfield_corpus_ir_datasets():
-    global document_texts, queries_texts, qrels_data
-    print("Loading Cranfield corpus, queries, and qrels using ir_datasets...")
-    try:
-        dataset = ir_datasets.load("cranfield")
-
-        # Load documents
-        for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
-            document_texts[doc.doc_id] = doc.text.strip()
-        print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
-
-        # Load queries
-        for query in tqdm(dataset.queries_iter(), desc="Loading Cranfield queries"):
-            queries_texts[query.query_id] = query.text.strip()
-        print(f"Loaded {len(queries_texts)} queries from Cranfield corpus.")
-
-        # Load qrels
-        for qrel in tqdm(dataset.qrels_iter(), desc="Loading Cranfield qrels"):
-            if qrel.query_id not in qrels_data:
-                qrels_data[qrel.query_id] = []
-            qrels_data[qrel.query_id].append({"doc_id": qrel.doc_id, "relevance": qrel.relevance})
-        print(f"Loaded qrels for {len(qrels_data)} queries.")
-
-    except Exception as e:
-        print(f"Error loading Cranfield corpus with ir_datasets: {e}")
-        print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
-
-
-# --- Helper function for lexical mask (now handles batches) ---
+# --- Helper function for lexical mask (now handles batches, but used for single input here) ---
 def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
     """
     Creates a batch of lexical BOW masks.
@@ -118,7 +79,7 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
 
 
 # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
-# These functions
+# These functions take single text input for the Explorer tab
 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
         return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
@@ -284,270 +245,10 @@ def predict_representation_explorer(model_choice, text):
     return "Please select a model."
 
 
-# --- Internal Core Representation Functions (now handle batches) ---
-def get_splade_cocondenser_representation_internal(texts, tokenizer, model):
-    """
-    Generates SPLADE representations for a batch of texts.
-    texts: list of strings
-    tokenizer: the tokenizer object
-    model: the SPLADE model
-    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-    """
-    if tokenizer is None or model is None: return None
-    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-    with torch.no_grad():
-        output = model(**inputs)
-
-    if hasattr(output, 'logits'):
-        # torch.max(..., dim=1)[0] reduces along sequence_length dimension,
-        # resulting in (batch_size, vocab_size)
-        splade_vectors = torch.max(
-            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
-            dim=1
-        )[0]
-        return splade_vectors
-    else:
-        print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
-        return None
-
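As a quick sanity check on the pooling in the function above, the following self-contained toy (shapes and values invented for illustration) shows that log(1 + ReLU(logits)), masked and max-pooled over the token dimension, yields one (vocab_size,) vector per input:

```python
import torch

logits = torch.randn(2, 3, 5)                    # (batch=2, tokens=3, vocab=5)
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])       # second sequence has one padding token

activated = torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1)
splade_vectors = torch.max(activated, dim=1)[0]  # pool over the token dimension
print(splade_vectors.shape)                      # torch.Size([2, 5])
```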
-def get_splade_lexical_representation_internal(texts, tokenizer, model):
-    """
-    Generates SPLADE-Lexical representations for a batch of texts.
-    texts: list of strings
-    tokenizer: the tokenizer object
-    model: the SPLADE-Lexical model
-    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-    """
-    if tokenizer is None or model is None: return None
-    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-    with torch.no_grad(): output = model(**inputs)
-    if hasattr(output, 'logits'):
-        splade_vectors = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0]
-        vocab_size = tokenizer.vocab_size
-        # create_lexical_bow_mask now returns (batch_size, vocab_size)
-        bow_masks = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
-        splade_vectors = splade_vectors * bow_masks # Element-wise multiplication, shapes (batch_size, vocab_size)
-        return splade_vectors
-    else:
-        print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
-        return None
-
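The element-wise multiplication above is what turns weighting-and-expansion into weighting-only: any vocabulary term the input never contained is zeroed. A tiny illustration with made-up numbers:

```python
import torch

splade_vec = torch.tensor([[0.8, 0.0, 0.5, 0.3]])  # model weights over a 4-term vocab
bow_mask   = torch.tensor([[1.0, 0.0, 0.0, 1.0]])  # input contained only terms 0 and 3
# Keeps 0.8 and 0.3, zeroes the expansion weight 0.5 at position 2.
print(splade_vec * bow_mask)
```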
-def get_splade_doc_representation_internal(texts, tokenizer, model):
-    """
-    Generates SPLADE-Doc (binary) representations for a batch of texts.
-    texts: list of strings
-    tokenizer: the tokenizer object
-    model: the SPLADE-Doc model (not directly used for logits, but for device)
-    Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-    """
-    if tokenizer is None or model is None: return None
-    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-    inputs = {k: v.to(model.device) for k, v in inputs.items()} # Ensure inputs are on the correct device
-    vocab_size = tokenizer.vocab_size
-    # create_lexical_bow_mask now returns (batch_size, vocab_size)
-    binary_splade_vectors = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
-    return binary_splade_vectors
-
-
-# --- Document Indexing Function (now uses batching) ---
-def index_documents(doc_model_choice):
-    global document_representations
-    if document_representations:
-        print("Documents already indexed. Skipping re-indexing.")
-        return True
-
-    tokenizer_to_use = None
-    model_to_use = None
-    representation_func_to_use = None
-
-    if doc_model_choice == "SPLADE-cocondenser-distil":
-        if tokenizer_splade is None or model_splade is None:
-            print("SPLADE-cocondenser-distil model not loaded for indexing.")
-            return False
-        tokenizer_to_use = tokenizer_splade
-        model_to_use = model_splade
-        representation_func_to_use = get_splade_cocondenser_representation_internal
-    elif doc_model_choice == "SPLADE-v3-Lexical":
-        if tokenizer_splade_lexical is None or model_splade_lexical is None:
-            print("SPLADE-v3-Lexical model not loaded for indexing.")
-            return False
-        tokenizer_to_use = tokenizer_splade_lexical
-        model_to_use = model_splade_lexical
-        representation_func_to_use = get_splade_lexical_representation_internal
-    elif doc_model_choice == "SPLADE-v3-Doc":
-        if tokenizer_splade_doc is None or model_splade_doc is None:
-            print("SPLADE-v3-Doc model not loaded for indexing.")
-            return False
-        tokenizer_to_use = tokenizer_splade_doc
-        model_to_use = model_splade_doc
-        representation_func_to_use = get_splade_doc_representation_internal
-    else:
-        print(f"Invalid model choice for document indexing: {doc_model_choice}")
-        return False
-
-    print(f"Indexing documents using {doc_model_choice}...")
-
-    doc_ids_list = list(document_texts.keys())
-    doc_texts_list = list(document_texts.values())
-
-    # --- BATCH SIZE FOR INDEXING ---
-    batch_size = 32 # You can adjust this value based on memory and performance
-
-    document_representations = {} # Ensure it's clear we're (re)building the index
-
-    # Iterate through documents in batches
-    for i in tqdm(range(0, len(doc_ids_list), batch_size), desc="Indexing Documents in Batches"):
-        batch_doc_ids = doc_ids_list[i:i + batch_size]
-        batch_doc_texts = doc_texts_list[i:i + batch_size]
-
-        sparse_vectors_batch = representation_func_to_use(batch_doc_texts, tokenizer_to_use, model_to_use)
-
-        if sparse_vectors_batch is not None:
-            # sparse_vectors_batch will have shape (batch_size, vocab_size)
-            for j, doc_id in enumerate(batch_doc_ids):
-                # Store each document's vector
-                document_representations[doc_id] = sparse_vectors_batch[j].cpu()
-        else:
-            print(f"Warning: Failed to get representation for a batch starting with doc_id {batch_doc_ids[0]}")
-
-    print(f"Finished indexing {len(document_representations)} documents.")
-    return True
-
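Since this function re-encodes the whole corpus on every app start, a simple on-disk cache is one way to skip the repeated work. A sketch against the file's own globals; `INDEX_PATH` and `load_or_build_index` are hypothetical names, not part of app.py:

```python
import os
import torch

INDEX_PATH = "doc_index.pt"  # hypothetical cache file

def load_or_build_index(model_choice):
    global document_representations
    if os.path.exists(INDEX_PATH):
        document_representations = torch.load(INDEX_PATH)
        print(f"Loaded cached index with {len(document_representations)} documents.")
        return True
    if index_documents(model_choice):
        torch.save(document_representations, INDEX_PATH)
        return True
    return False
```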
-# --- Retrieval Function (for Retrieval Tab) ---
-def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
-    if not document_representations:
-        return "Document index is not loaded or empty. Please ensure documents are indexed.", []
-
-    query_vector = None
-    query_tokenizer = None
-    query_model = None
-
-    # These internal calls still use single text input for the query
-    if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
-        query_tokenizer = tokenizer_splade
-        query_model = model_splade
-        query_vector = get_splade_cocondenser_representation_internal([query_text], query_tokenizer, query_model)
-    elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
-        query_tokenizer = tokenizer_splade_lexical
-        query_model = model_splade_lexical
-        query_vector = get_splade_lexical_representation_internal([query_text], query_tokenizer, query_model)
-    elif query_model_choice == "SPLADE-v3-Doc (binary)":
-        query_tokenizer = tokenizer_splade_doc
-        query_model = model_splade_doc
-        query_vector = get_splade_doc_representation_internal([query_text], query_tokenizer, query_model)
-    else:
-        return "Invalid query model choice.", []
-
-    if query_vector is None:
-        return "Failed to get query representation. Check console for model loading errors.", []
-
-    # Since internal functions now return batches, take the first (and only) item for single query
-    query_vector = query_vector.squeeze(0).cpu()
-
-    scores = {}
-    for doc_id, doc_vec in document_representations.items():
-        score = torch.dot(query_vector, doc_vec).item()
-        scores[doc_id] = score
-
-    sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
-    top_results = sorted_scores[:top_k]
-
-    formatted_output = f"Retrieval Results for Query: '{query_text}'\n"
-    formatted_output += f"Using Query Model: **{query_model_choice}**\n"
-    formatted_output += f"Documents Indexed with: **{indexed_doc_model_name}**\n\n"
-
-    if not top_results:
-        formatted_output += "No documents found or scored.\n"
-    else:
-        for i, (doc_id, score) in enumerate(top_results):
-            doc_text = document_texts.get(doc_id, "Document text not available.")
-            formatted_output += f"**{i+1}. Document ID: {doc_id}** (Score: {score:.4f})\n"
-            formatted_output += f"> {doc_text[:300]}...\n\n"
-
    return formatted_output, top_results
-
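The per-document `torch.dot` loop above pays Python-call overhead for every document. If the whole index fits in memory, the same scores come from a single matrix-vector product; a sketch reusing the function's own variables:

```python
import torch

doc_ids = list(document_representations.keys())
doc_matrix = torch.stack([document_representations[d] for d in doc_ids])  # (N, vocab_size)

scores = doc_matrix @ query_vector          # all N dot products in one call
top = torch.topk(scores, k=top_k)
top_results = [(doc_ids[i], scores[i].item()) for i in top.indices.tolist()]
```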
-# --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
-def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
-    formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
-    return formatted_output
-
-# --- New function to get specific retrieval examples ---
-def get_specific_retrieval_examples():
-    if not queries_texts or not qrels_data or not document_texts:
-        return "Queries, qrels, or documents not loaded. Please check initial loading."
-
-    high_qrel_threshold = 3 # Relevance score of 3 or 4 for Cranfield is generally considered high
-    low_qrel_threshold = 1 # Relevance score of 0 or 1 for Cranfield is generally considered low
-
-    eligible_query_ids = []
-    for qid, qrels in qrels_data.items():
-        has_high_qrel = any(item['relevance'] >= high_qrel_threshold for item in qrels)
-        has_low_qrel = any(item['relevance'] <= low_qrel_threshold for item in qrels)
-        if has_high_qrel and has_low_qrel:
-            eligible_query_ids.append(qid)
-
-    if not eligible_query_ids:
-        return "Could not find a query with both high and low relevance documents in the loaded qrels."
-
-    # Pick a random eligible query
-    random_query_id = random.choice(eligible_query_ids)
-    full_query_text = queries_texts.get(random_query_id, "Query text not found.")
-    query_snippet = full_query_text[:300] + "..." if len(full_query_text) > 300 else full_query_text
-
-    qrels_for_query = qrels_data[random_query_id]
-
-    high_qrel_docs = [item for item in qrels_for_query if item['relevance'] >= high_qrel_threshold]
-    low_qrel_docs = [item for item in qrels_for_query if item['relevance'] <= low_qrel_threshold]
-
-    selected_high_doc_id = random.choice(high_qrel_docs)['doc_id'] if high_qrel_docs else None
-    selected_low_doc_id = random.choice(low_qrel_docs)['doc_id'] if low_qrel_docs else None
-
-    output_str = f"### Random Query Example\n\n"
-    output_str += f"**Query ID:** {random_query_id}\n"
-    output_str += f"**Query Snippet:** {query_snippet}\n\n" # Changed to snippet
-
-    if selected_high_doc_id:
-        full_doc_text = document_texts.get(selected_high_doc_id, "Document text not available.")
-        doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
-        output_str += f"### Highly Relevant Document (Qrel >= {high_qrel_threshold})\n"
-        output_str += f"**Document ID:** {selected_high_doc_id}\n"
-        output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
-    else:
-        output_str += "No highly relevant document found for this query.\n\n"
-
-    if selected_low_doc_id:
-        full_doc_text = document_texts.get(selected_low_doc_id, "Document text not available.")
-        doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
-        output_str += f"### Lowly Relevant Document (Qrel <= {low_qrel_threshold})\n"
-        output_str += f"**Document ID:** {selected_low_doc_id}\n"
-        output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
-    else:
-        output_str += "No lowly relevant document found for this query.\n\n"
-
-    return output_str
-
-
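The qrels loaded above would also support a quick quantitative check of the retrieval results, not just spot examples. A minimal Precision@k sketch against the same globals; `precision_at_k` and its relevance threshold are assumptions, not part of app.py:

```python
def precision_at_k(query_id, ranked_doc_ids, k=5, relevant_threshold=3):
    # Hypothetical helper: fraction of the top-k retrieved docs judged relevant.
    judged = {item["doc_id"]: item["relevance"] for item in qrels_data.get(query_id, [])}
    hits = sum(1 for d in ranked_doc_ids[:k] if judged.get(d, 0) >= relevant_threshold)
    return hits / k
```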
-# --- Initial Load and Indexing Calls ---
-# This part runs once when the app starts.
-load_cranfield_corpus_ir_datasets()
-
-if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
-    index_documents(initial_doc_model_for_indexing)
-elif initial_doc_model_for_indexing == "SPLADE-v3-Lexical" and model_splade_lexical is not None:
-    index_documents(initial_doc_model_for_indexing)
-elif initial_doc_model_for_indexing == "SPLADE-v3-Doc" and model_splade_doc is not None:
-    index_documents(initial_doc_model_for_indexing)
-else:
-    print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")
-
-
 # --- Gradio Interface Setup with Tabs ---
 with gr.Blocks(title="SPLADE Demos") as demo:
-    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer
-    gr.Markdown("Explore different SPLADE models and their sparse representation types
+    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer") # Updated title
+    gr.Markdown("Explore different SPLADE models and their sparse representation types.") # Updated description
 
     with gr.Tabs():
         with gr.TabItem("Sparse Representation Explorer"):
@@ -575,49 +276,4 @@ with gr.Blocks(title="SPLADE Demos") as demo:
                 # live=True # Setting live=True might be slow for complex models on every keystroke
             )
 
-        with gr.TabItem("Document Retrieval Demo"):
-            gr.Markdown("### Retrieve Documents from Cranfield Collection")
-            gr.Interface(
-                fn=predict_retrieval_gradio,
-                inputs=[
-                    gr.Textbox(
-                        lines=3,
-                        label="Enter your query text here:",
-                        placeholder="e.g., Does high-dose vitamin C cure cancer?"
-                    ),
-                    gr.Radio(
-                        [
-                            "SPLADE-cocondenser-distil (weighting and expansion)",
-                            "SPLADE-v3-Lexical (weighting)",
-                            "SPLADE-v3-Doc (binary)"
-                        ],
-                        label="Choose Query Representation Model",
-                        value="SPLADE-cocondenser-distil (weighting and expansion)"
-                    ),
-                    gr.Radio(
-                        [
-                            "SPLADE-cocondenser-distil",
-                            "SPLADE-v3-Lexical",
-                            "SPLADE-v3-Doc"
-                        ],
-                        label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
-                        value=initial_doc_model_for_indexing,
-                        interactive=False # This radio is fixed for simplicity
-                    )
-                ],
-                outputs=gr.Markdown(),
-                allow_flagging="never",
-                # live=True # retrieval is too heavy for live
-            )
-
-            gr.Markdown("---") # Separator
-            gr.Markdown("### Get Specific Retrieval Examples")
-            specific_example_output = gr.Markdown()
-            specific_example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
-            specific_example_button.click(
-                fn=get_specific_retrieval_examples,
-                inputs=[],
-                outputs=specific_example_output
-            )
-
 demo.launch()