SiddharthAK committed on
Commit 4e0cddb · verified · 1 Parent(s): b0796be

Update app.py

Files changed (1):
  1. app.py +7 -351
app.py CHANGED
@@ -2,12 +2,10 @@ import gradio as gr
  from transformers import AutoTokenizer, AutoModelForMaskedLM
  import torch
  import numpy as np
- from tqdm.auto import tqdm
- import os
- import ir_datasets
- import random # Added for random selection
+ from tqdm.auto import tqdm # Still useful for model loading progress if desired, but not strictly necessary for this simplified version
+ import os # Still useful for general purpose, but not explicitly used in this simplified version
 
- # --- Model Loading (Keep as is) ---
+ # --- Model Loading ---
  tokenizer_splade = None
  model_splade = None
  tokenizer_splade_lexical = None
@@ -48,44 +46,7 @@ except Exception as e:
  print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
 
 
- # --- Global Variables for Document Index and Qrels ---
- document_representations = {} # Stores {doc_id: sparse_vector}
- document_texts = {} # Stores {doc_id: doc_text}
- queries_texts = {} # Stores {query_id: query_text}
- qrels_data = {} # Stores {query_id: [{doc_id: str, relevance: int}, ...]}
- initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
-
-
- # --- Load Cranfield Corpus, Queries, and Qrels using ir_datasets ---
- def load_cranfield_corpus_ir_datasets():
-     global document_texts, queries_texts, qrels_data
-     print("Loading Cranfield corpus, queries, and qrels using ir_datasets...")
-     try:
-         dataset = ir_datasets.load("cranfield")
-
-         # Load documents
-         for doc in tqdm(dataset.docs_iter(), desc="Loading Cranfield documents"):
-             document_texts[doc.doc_id] = doc.text.strip()
-         print(f"Loaded {len(document_texts)} documents from Cranfield corpus.")
-
-         # Load queries
-         for query in tqdm(dataset.queries_iter(), desc="Loading Cranfield queries"):
-             queries_texts[query.query_id] = query.text.strip()
-         print(f"Loaded {len(queries_texts)} queries from Cranfield corpus.")
-
-         # Load qrels
-         for qrel in tqdm(dataset.qrels_iter(), desc="Loading Cranfield qrels"):
-             if qrel.query_id not in qrels_data:
-                 qrels_data[qrel.query_id] = []
-             qrels_data[qrel.query_id].append({"doc_id": qrel.doc_id, "relevance": qrel.relevance})
-         print(f"Loaded qrels for {len(qrels_data)} queries.")
-
-     except Exception as e:
-         print(f"Error loading Cranfield corpus with ir_datasets: {e}")
-         print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
-
-
- # --- Helper function for lexical mask (now handles batches) ---
+ # --- Helper function for lexical mask (now handles batches, but used for single input here) ---
  def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
      """
      Creates a batch of lexical BOW masks.
@@ -118,7 +79,7 @@ def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
 
 
  # --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
- # These functions still take single text input for the Explorer tab
+ # These functions take single text input for the Explorer tab
  def get_splade_cocondenser_representation(text):
      if tokenizer_splade is None or model_splade is None:
          return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors."
@@ -284,270 +245,10 @@ def predict_representation_explorer(model_choice, text):
      return "Please select a model."
 
 
- # --- Internal Core Representation Functions (now handle batches) ---
- def get_splade_cocondenser_representation_internal(texts, tokenizer, model):
-     """
-     Generates SPLADE representations for a batch of texts.
-     texts: list of strings
-     tokenizer: the tokenizer object
-     model: the SPLADE model
-     Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-     """
-     if tokenizer is None or model is None: return None
-     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-     with torch.no_grad():
-         output = model(**inputs)
-
-     if hasattr(output, 'logits'):
-         # torch.max(..., dim=1)[0] reduces along the sequence_length dimension,
-         # resulting in (batch_size, vocab_size)
-         splade_vectors = torch.max(
-             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
-             dim=1
-         )[0]
-         return splade_vectors
-     else:
-         print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
-         return None
-
- def get_splade_lexical_representation_internal(texts, tokenizer, model):
-     """
-     Generates SPLADE-Lexical representations for a batch of texts.
-     texts: list of strings
-     tokenizer: the tokenizer object
-     model: the SPLADE-Lexical model
-     Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-     """
-     if tokenizer is None or model is None: return None
-     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-     with torch.no_grad(): output = model(**inputs)
-     if hasattr(output, 'logits'):
-         splade_vectors = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0]
-         vocab_size = tokenizer.vocab_size
-         # create_lexical_bow_mask now returns (batch_size, vocab_size)
-         bow_masks = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
-         splade_vectors = splade_vectors * bow_masks # Element-wise multiplication, shapes (batch_size, vocab_size)
-         return splade_vectors
-     else:
-         print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
-         return None
-
- def get_splade_doc_representation_internal(texts, tokenizer, model):
-     """
-     Generates SPLADE-Doc (binary) representations for a batch of texts.
-     texts: list of strings
-     tokenizer: the tokenizer object
-     model: the SPLADE-Doc model (not directly used for logits, but for device)
-     Returns: torch.Tensor of shape (batch_size, vocab_size) or None
-     """
-     if tokenizer is None or model is None: return None
-     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
-     inputs = {k: v.to(model.device) for k, v in inputs.items()} # Ensure inputs are on the correct device
-     vocab_size = tokenizer.vocab_size
-     # create_lexical_bow_mask now returns (batch_size, vocab_size)
-     binary_splade_vectors = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer)
-     return binary_splade_vectors
-
-
- # --- Document Indexing Function (now uses batching) ---
- def index_documents(doc_model_choice):
-     global document_representations
-     if document_representations:
-         print("Documents already indexed. Skipping re-indexing.")
-         return True
-
-     tokenizer_to_use = None
-     model_to_use = None
-     representation_func_to_use = None
-
-     if doc_model_choice == "SPLADE-cocondenser-distil":
-         if tokenizer_splade is None or model_splade is None:
-             print("SPLADE-cocondenser-distil model not loaded for indexing.")
-             return False
-         tokenizer_to_use = tokenizer_splade
-         model_to_use = model_splade
-         representation_func_to_use = get_splade_cocondenser_representation_internal
-     elif doc_model_choice == "SPLADE-v3-Lexical":
-         if tokenizer_splade_lexical is None or model_splade_lexical is None:
-             print("SPLADE-v3-Lexical model not loaded for indexing.")
-             return False
-         tokenizer_to_use = tokenizer_splade_lexical
-         model_to_use = model_splade_lexical
-         representation_func_to_use = get_splade_lexical_representation_internal
-     elif doc_model_choice == "SPLADE-v3-Doc":
-         if tokenizer_splade_doc is None or model_splade_doc is None:
-             print("SPLADE-v3-Doc model not loaded for indexing.")
-             return False
-         tokenizer_to_use = tokenizer_splade_doc
-         model_to_use = model_splade_doc
-         representation_func_to_use = get_splade_doc_representation_internal
-     else:
-         print(f"Invalid model choice for document indexing: {doc_model_choice}")
-         return False
-
-     print(f"Indexing documents using {doc_model_choice}...")
-
-     doc_ids_list = list(document_texts.keys())
-     doc_texts_list = list(document_texts.values())
-
-     # --- BATCH SIZE FOR INDEXING ---
-     batch_size = 32 # You can adjust this value based on memory and performance
-
-     document_representations = {} # Ensure it's clear we're (re)building the index
-
-     # Iterate through documents in batches
-     for i in tqdm(range(0, len(doc_ids_list), batch_size), desc="Indexing Documents in Batches"):
-         batch_doc_ids = doc_ids_list[i:i + batch_size]
-         batch_doc_texts = doc_texts_list[i:i + batch_size]
-
-         sparse_vectors_batch = representation_func_to_use(batch_doc_texts, tokenizer_to_use, model_to_use)
-
-         if sparse_vectors_batch is not None:
-             # sparse_vectors_batch will have shape (batch_size, vocab_size)
-             for j, doc_id in enumerate(batch_doc_ids):
-                 # Store each document's vector
-                 document_representations[doc_id] = sparse_vectors_batch[j].cpu()
-         else:
-             print(f"Warning: Failed to get representation for a batch starting with doc_id {batch_doc_ids[0]}")
-
-     print(f"Finished indexing {len(document_representations)} documents.")
-     return True
-
- # --- Retrieval Function (for Retrieval Tab) ---
- def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
-     if not document_representations:
-         return "Document index is not loaded or empty. Please ensure documents are indexed.", []
-
-     query_vector = None
-     query_tokenizer = None
-     query_model = None
-
-     # These internal calls still use single text input for the query
-     if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
-         query_tokenizer = tokenizer_splade
-         query_model = model_splade
-         query_vector = get_splade_cocondenser_representation_internal([query_text], query_tokenizer, query_model)
-     elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
-         query_tokenizer = tokenizer_splade_lexical
-         query_model = model_splade_lexical
-         query_vector = get_splade_lexical_representation_internal([query_text], query_tokenizer, query_model)
-     elif query_model_choice == "SPLADE-v3-Doc (binary)":
-         query_tokenizer = tokenizer_splade_doc
-         query_model = model_splade_doc
-         query_vector = get_splade_doc_representation_internal([query_text], query_tokenizer, query_model)
-     else:
-         return "Invalid query model choice.", []
-
-     if query_vector is None:
-         return "Failed to get query representation. Check console for model loading errors.", []
-
-     # Since internal functions now return batches, take the first (and only) item for single query
-     query_vector = query_vector.squeeze(0).cpu()
-
-     scores = {}
-     for doc_id, doc_vec in document_representations.items():
-         score = torch.dot(query_vector, doc_vec).item()
-         scores[doc_id] = score
-
-     sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
-     top_results = sorted_scores[:top_k]
-
-     formatted_output = f"Retrieval Results for Query: '{query_text}'\n"
-     formatted_output += f"Using Query Model: **{query_model_choice}**\n"
-     formatted_output += f"Documents Indexed with: **{indexed_doc_model_name}**\n\n"
-
-     if not top_results:
-         formatted_output += "No documents found or scored.\n"
-     else:
-         for i, (doc_id, score) in enumerate(top_results):
-             doc_text = document_texts.get(doc_id, "Document text not available.")
-             formatted_output += f"**{i+1}. Document ID: {doc_id}** (Score: {score:.4f})\n"
-             formatted_output += f"> {doc_text[:300]}...\n\n"
-
-     return formatted_output, top_results
-
- # --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
- def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
-     formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
-     return formatted_output
-
- # --- New function to get specific retrieval examples ---
- def get_specific_retrieval_examples():
-     if not queries_texts or not qrels_data or not document_texts:
-         return "Queries, qrels, or documents not loaded. Please check initial loading."
-
-     high_qrel_threshold = 3 # Relevance score of 3 or 4 for Cranfield is generally considered high
-     low_qrel_threshold = 1 # Relevance score of 0 or 1 for Cranfield is generally considered low
-
-     eligible_query_ids = []
-     for qid, qrels in qrels_data.items():
-         has_high_qrel = any(item['relevance'] >= high_qrel_threshold for item in qrels)
-         has_low_qrel = any(item['relevance'] <= low_qrel_threshold for item in qrels)
-         if has_high_qrel and has_low_qrel:
-             eligible_query_ids.append(qid)
-
-     if not eligible_query_ids:
-         return "Could not find a query with both high and low relevance documents in the loaded qrels."
-
-     # Pick a random eligible query
-     random_query_id = random.choice(eligible_query_ids)
-     full_query_text = queries_texts.get(random_query_id, "Query text not found.")
-     query_snippet = full_query_text[:300] + "..." if len(full_query_text) > 300 else full_query_text
-
-     qrels_for_query = qrels_data[random_query_id]
-
-     high_qrel_docs = [item for item in qrels_for_query if item['relevance'] >= high_qrel_threshold]
-     low_qrel_docs = [item for item in qrels_for_query if item['relevance'] <= low_qrel_threshold]
-
-     selected_high_doc_id = random.choice(high_qrel_docs)['doc_id'] if high_qrel_docs else None
-     selected_low_doc_id = random.choice(low_qrel_docs)['doc_id'] if low_qrel_docs else None
-
-     output_str = f"### Random Query Example\n\n"
-     output_str += f"**Query ID:** {random_query_id}\n"
-     output_str += f"**Query Snippet:** {query_snippet}\n\n" # Changed to snippet
-
-     if selected_high_doc_id:
-         full_doc_text = document_texts.get(selected_high_doc_id, "Document text not available.")
-         doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
-         output_str += f"### Highly Relevant Document (Qrel >= {high_qrel_threshold})\n"
-         output_str += f"**Document ID:** {selected_high_doc_id}\n"
-         output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
-     else:
-         output_str += "No highly relevant document found for this query.\n\n"
-
-     if selected_low_doc_id:
-         full_doc_text = document_texts.get(selected_low_doc_id, "Document text not available.")
-         doc_snippet = full_doc_text[:500] + "..." if len(full_doc_text) > 500 else full_doc_text
-         output_str += f"### Lowly Relevant Document (Qrel <= {low_qrel_threshold})\n"
-         output_str += f"**Document ID:** {selected_low_doc_id}\n"
-         output_str += f"**Document Snippet:** {doc_snippet}\n\n" # Changed to snippet
-     else:
-         output_str += "No lowly relevant document found for this query.\n\n"
-
-     return output_str
-
-
- # --- Initial Load and Indexing Calls ---
- # This part runs once when the app starts.
- load_cranfield_corpus_ir_datasets()
-
- if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
-     index_documents(initial_doc_model_for_indexing)
- elif initial_doc_model_for_indexing == "SPLADE-v3-Lexical" and model_splade_lexical is not None:
-     index_documents(initial_doc_model_for_indexing)
- elif initial_doc_model_for_indexing == "SPLADE-v3-Doc" and model_splade_doc is not None:
-     index_documents(initial_doc_model_for_indexing)
- else:
-     print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")
-
-
  # --- Gradio Interface Setup with Tabs ---
  with gr.Blocks(title="SPLADE Demos") as demo:
-     gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer & Document Retrieval")
-     gr.Markdown("Explore different SPLADE models and their sparse representation types, or perform document retrieval on a test collection.")
+     gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer") # Updated title
+     gr.Markdown("Explore different SPLADE models and their sparse representation types.") # Updated description
 
      with gr.Tabs():
          with gr.TabItem("Sparse Representation Explorer"):
@@ -575,49 +276,4 @@ with gr.Blocks(title="SPLADE Demos") as demo:
              # live=True # Setting live=True might be slow for complex models on every keystroke
          )
 
-         with gr.TabItem("Document Retrieval Demo"):
-             gr.Markdown("### Retrieve Documents from Cranfield Collection")
-             gr.Interface(
-                 fn=predict_retrieval_gradio,
-                 inputs=[
-                     gr.Textbox(
-                         lines=3,
-                         label="Enter your query text here:",
-                         placeholder="e.g., Does high-dose vitamin C cure cancer?"
-                     ),
-                     gr.Radio(
-                         [
-                             "SPLADE-cocondenser-distil (weighting and expansion)",
-                             "SPLADE-v3-Lexical (weighting)",
-                             "SPLADE-v3-Doc (binary)"
-                         ],
-                         label="Choose Query Representation Model",
-                         value="SPLADE-cocondenser-distil (weighting and expansion)"
-                     ),
-                     gr.Radio(
-                         [
-                             "SPLADE-cocondenser-distil",
-                             "SPLADE-v3-Lexical",
-                             "SPLADE-v3-Doc"
-                         ],
-                         label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
-                         value=initial_doc_model_for_indexing,
-                         interactive=False # This radio is fixed for simplicity
-                     )
-                 ],
-                 outputs=gr.Markdown(),
-                 allow_flagging="never",
-                 # live=True # retrieval is too heavy for live
-             )
-
-             gr.Markdown("---") # Separator
-             gr.Markdown("### Get Specific Retrieval Examples")
-             specific_example_output = gr.Markdown()
-             specific_example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
-             specific_example_button.click(
-                 fn=get_specific_retrieval_examples,
-                 inputs=[],
-                 outputs=specific_example_output
-             )
-
  demo.launch()
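
A few notes on the code touched by this commit. The log-saturation and max-pooling step shared by the removed batch functions (and the surviving Explorer-tab functions) is easy to reproduce standalone. A minimal sketch, using the public naver/splade-cocondenser-ensembledistil checkpoint as an assumed stand-in, since the Space's exact model names are defined outside the hunks shown here:

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Assumed checkpoint: any SPLADE-style MLM head behaves the same way here.
model_name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

texts = ["sparse retrieval with learned term weighting"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits                         # (batch, seq_len, vocab)

weights = torch.log1p(torch.relu(logits))                   # log(1 + ReLU(logits))
weights = weights * inputs["attention_mask"].unsqueeze(-1)  # zero out padding positions
splade_vec = weights.max(dim=1).values                      # max-pool over seq_len -> (batch, vocab)

# Show the strongest terms in the (very sparse) vector.
vals = splade_vec[0]
nonzero = vals.nonzero().squeeze(-1)
top = nonzero[vals[nonzero].argsort(descending=True)][:10]
print([(tokenizer.convert_ids_to_tokens([i.item()])[0], round(vals[i].item(), 3)) for i in top])
```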
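The create_lexical_bow_mask helper is kept by this commit, but its body is truncated in the hunks above. The sketch below is one plausible implementation of what its docstring and call sites describe, a binary (batch_size, vocab_size) bag-of-words mask with special tokens zeroed, and is an assumption rather than the file's actual code:

```python
import torch

def bow_mask_sketch(input_ids_batch: torch.Tensor, vocab_size: int, tokenizer) -> torch.Tensor:
    # Hypothetical equivalent of create_lexical_bow_mask: mark every token id
    # that appears in each row, then drop [CLS]/[SEP]/[PAD]/etc.
    mask = torch.zeros(input_ids_batch.size(0), vocab_size, device=input_ids_batch.device)
    mask.scatter_(1, input_ids_batch, 1.0)           # set mask[b, id] = 1 for each id in row b
    for special_id in tokenizer.all_special_ids:     # special tokens carry no lexical signal
        if 0 <= special_id < vocab_size:
            mask[:, special_id] = 0.0
    return mask
```

Multiplying the SPLADE activations by such a mask, as the removed get_splade_lexical_representation_internal did, keeps term weighting but disables expansion to terms absent from the input.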
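The removed retrieve_documents ranked by exhaustive sparse dot products, one torch.dot call per document in a Python loop. The same ranking can be computed in a single matrix-vector product; a hedged sketch, where doc_ids and doc_matrix are illustrative names built from the removed document_representations dict:

```python
import torch

def rank_documents(query_vec: torch.Tensor, doc_ids: list, doc_matrix: torch.Tensor, top_k: int = 5):
    # doc_matrix: (num_docs, vocab_size), rows stacked in the same order as doc_ids.
    scores = doc_matrix @ query_vec                      # one dot product per row
    top = torch.topk(scores, k=min(top_k, len(doc_ids)))
    return [(doc_ids[i], v.item()) for i, v in zip(top.indices.tolist(), top.values)]

# Usage sketch, assuming an index built as in the removed index_documents():
# doc_ids = list(document_representations.keys())
# doc_matrix = torch.stack([document_representations[d] for d in doc_ids])
# results = rank_documents(query_vector.squeeze(0).cpu(), doc_ids, doc_matrix)
```

For a corpus as small as Cranfield the per-document loop is acceptable; the stacked form mainly matters for larger collections or repeated queries.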
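Lastly, the retrieval tab's button wiring removed above follows the standard Gradio Blocks pattern. A self-contained skeleton with a hypothetical stand-in handler in place of the removed get_specific_retrieval_examples:

```python
import gradio as gr

def show_example() -> str:
    # Stand-in for the removed get_specific_retrieval_examples handler.
    return "### Random Query Example\n(placeholder output)"

with gr.Blocks(title="SPLADE Demos") as demo:
    with gr.Tabs():
        with gr.TabItem("Document Retrieval Demo"):
            example_output = gr.Markdown()
            example_button = gr.Button("Get Random Query with High/Low Qrel Docs")
            # click() routes the button press to the handler: no inputs, Markdown out.
            example_button.click(fn=show_example, inputs=[], outputs=example_output)

demo.launch()
```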
 