SiddharthAK committed
Commit 01b1a90 · verified · 1 Parent(s): e05e8c0

added retrieval feature

Files changed (1):
  1. app.py +269 -52

app.py CHANGED
@@ -1,8 +1,12 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
+from transformers import AutoTokenizer, AutoModelForMaskedLM
 import torch
+import numpy as np
+from tqdm.auto import tqdm
+import os
+import ir_datasets

-# --- Model Loading ---
+# --- Model Loading (Keep as is) ---
 tokenizer_splade = None
 model_splade = None
 tokenizer_splade_lexical = None
@@ -14,7 +18,7 @@ model_splade_doc = None
 try:
     tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
     model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
-    model_splade.eval() # Set to evaluation mode for inference
+    model_splade.eval()
     print("SPLADE-cocondenser-distil model loaded successfully!")
 except Exception as e:
     print(f"Error loading SPLADE-cocondenser-distil model: {e}")
@@ -25,7 +29,7 @@ try:
     splade_lexical_model_name = "naver/splade-v3-lexical"
     tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
     model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
-    model_splade_lexical.eval() # Set to evaluation mode for inference
+    model_splade_lexical.eval()
     print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
 except Exception as e:
     print(f"Error loading SPLADE-v3-Lexical model: {e}")
@@ -36,19 +40,35 @@ try:
     splade_doc_model_name = "naver/splade-v3-doc"
     tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
     model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name)
-    model_splade_doc.eval() # Set to evaluation mode for inference
+    model_splade_doc.eval()
     print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
 except Exception as e:
     print(f"Error loading SPLADE-v3-Doc model: {e}")
     print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


-# --- Helper function for lexical mask ---
+# --- Global Variables for Document Index ---
+document_representations = {} # Stores {doc_id: sparse_vector}
+document_texts = {} # Stores {doc_id: doc_text}
+initial_doc_model_for_indexing = "SPLADE-cocondenser-distil" # Fixed for initial demo index
+
+
+# --- Load SciFact Corpus using ir_datasets ---
+def load_scifact_corpus_ir_datasets():
+    global document_texts
+    print("Loading SciFact corpus using ir_datasets...")
+    try:
+        dataset = ir_datasets.load("scifact")
+        for doc in tqdm(dataset.docs_iter(), desc="Loading SciFact documents"):
+            document_texts[doc.doc_id] = doc.text.strip()
+        print(f"Loaded {len(document_texts)} documents from SciFact corpus.")
+    except Exception as e:
+        print(f"Error loading SciFact corpus with ir_datasets: {e}")
+        print("Please ensure 'ir_datasets' is installed and your internet connection is stable.")
+
+
+# --- Helper function for lexical mask (Keep as is) ---
 def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
-    """
-    Creates a binary bag-of-words mask from input_ids,
-    zeroing out special tokens and padding.
-    """
     bow_mask = torch.zeros(vocab_size, device=input_ids.device)
     meaningful_token_ids = []
     for token_id in input_ids.squeeze().tolist():
@@ -60,14 +80,15 @@ def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
             tokenizer.unk_token_id
         ]:
             meaningful_token_ids.append(token_id)
-
+
     if meaningful_token_ids:
         bow_mask[list(set(meaningful_token_ids))] = 1

     return bow_mask.unsqueeze(0)


-# --- Core Representation Functions ---
+# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
+# These are your original functions, re-added.

 def get_splade_cocondenser_representation(text):
     if tokenizer_splade is None or model_splade is None:
@@ -80,7 +101,6 @@ def get_splade_cocondenser_representation(text):
         output = model_splade(**inputs)

     if hasattr(output, 'logits'):
-        # Standard SPLADE calculation for learned weighting and expansion
         splade_vector = torch.max(
             torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
             dim=1
@@ -90,7 +110,7 @@

         indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
         if not isinstance(indices, list):
-            indices = [indices]
+            indices = [indices] if indices else []

         values = splade_vector[indices].cpu().tolist()
         token_weights = dict(zip(indices, values))
@@ -139,12 +159,12 @@ def get_splade_lexical_representation(text):
         vocab_size = tokenizer_splade_lexical.vocab_size
         bow_mask = create_lexical_bow_mask(
             inputs['input_ids'], vocab_size, tokenizer_splade_lexical
-        ).squeeze()
+        ).squeeze()
         splade_vector = splade_vector * bow_mask

         indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
         if not isinstance(indices, list):
-            indices = [indices]
+            indices = [indices] if indices else []

         values = splade_vector[indices].cpu().tolist()
         token_weights = dict(zip(indices, values))
@@ -171,7 +191,6 @@ def get_splade_lexical_representation(text):
     return formatted_output


-# Function for SPLADE-v3-Doc representation (Binary Sparse - Lexical Only)
 def get_splade_doc_representation(text):
     if tokenizer_splade_doc is None or model_splade_doc is None:
         return "SPLADE-v3-Doc model is not loaded. Please check the console for loading errors."
@@ -185,19 +204,15 @@
     if not hasattr(output, "logits"):
         return "SPLADE-v3-Doc model output structure not as expected. 'logits' not found."

-    # For SPLADE-v3-Doc, assuming output is designed to be binary and lexical-only.
-    # We will derive the output directly from the input tokens themselves,
-    # as the model's primary role in this context is as a pre-trained LM feature extractor
-    # for a document-side, lexical-only binary sparse representation.
     vocab_size = tokenizer_splade_doc.vocab_size
-    binary_splade_vector = create_lexical_bow_mask( # Use the BOW mask directly for binary
+    binary_splade_vector = create_lexical_bow_mask(
         inputs['input_ids'], vocab_size, tokenizer_splade_doc
     ).squeeze()

     indices = torch.nonzero(binary_splade_vector).squeeze().cpu().tolist()
     if not isinstance(indices, list):
         indices = [indices] if indices else []
-
+
     values = [1.0] * len(indices) # All values are 1 for binary representation
     token_weights = dict(zip(indices, values))

@@ -226,41 +241,243 @@
     return formatted_output


-# --- Unified Prediction Function for Gradio ---
-def predict_representation(model_choice, text):
+# --- Unified Prediction Function for the Explorer Tab ---
+def predict_representation_explorer(model_choice, text):
     if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
         return get_splade_cocondenser_representation(text)
     elif model_choice == "SPLADE-v3-Lexical (weighting)":
         return get_splade_lexical_representation(text)
     elif model_choice == "SPLADE-v3-Doc (binary)":
-        return get_splade_doc_representation(text)
+        return get_splade_doc_representation(text)
     else:
         return "Please select a model."

-# --- Gradio Interface Setup ---
-demo = gr.Interface(
-    fn=predict_representation,
-    inputs=[
-        gr.Radio(
-            [
-                "SPLADE-cocondenser-distil (weighting and expansion)",
-                "SPLADE-v3-Lexical (weighting)",
-                "SPLADE-v3-Doc (binary)"
-            ],
-            label="Choose Representation Model",
-            value="SPLADE-cocondenser-distil (weighting and expansion)" # Corrected default value
-        ),
-        gr.Textbox(
-            lines=5,
-            label="Enter your query or document text here:",
-            placeholder="e.g., Why is Padua the nicest city in Italy?"
-        )
-    ],
-    outputs=gr.Markdown(),
-    title="🌌 Sparse Representation Generator",
-    description="Explore different SPLADE models and their sparse representation types: weighted and expansive, weighted and lexical-only, or strictly binary.",
-    allow_flagging="never"
-)
-
-# Launch the Gradio app
+
+# --- Internal Core Representation Functions (Return Raw Vectors - for Retrieval Tab) ---
+# These are the ones ending with _internal, as previously defined.
+
+def get_splade_cocondenser_representation_internal(text, tokenizer, model):
+    if tokenizer is None or model is None: return None
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    with torch.no_grad(): output = model(**inputs)
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
+        return splade_vector
+    else:
+        print("Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.")
+        return None
+
+def get_splade_lexical_representation_internal(text, tokenizer, model):
+    if tokenizer is None or model is None: return None
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    with torch.no_grad(): output = model(**inputs)
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0].squeeze()
+        vocab_size = tokenizer.vocab_size
+        bow_mask = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
+        splade_vector = splade_vector * bow_mask
+        return splade_vector
+    else:
+        print("Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.")
+        return None
+
+def get_splade_doc_representation_internal(text, tokenizer, model):
+    if tokenizer is None or model is None: return None
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+    vocab_size = tokenizer.vocab_size
+    binary_splade_vector = create_lexical_bow_mask(inputs['input_ids'], vocab_size, tokenizer).squeeze()
+    return binary_splade_vector
+
+
+# --- Document Indexing Function (for Retrieval Tab) ---
+def index_documents(doc_model_choice):
+    global document_representations
+    if document_representations:
+        print("Documents already indexed. Skipping re-indexing.")
+        return True
+
+    tokenizer_to_use = None
+    model_to_use = None
+    representation_func_to_use = None
+
+    if doc_model_choice == "SPLADE-cocondenser-distil":
+        if tokenizer_splade is None or model_splade is None:
+            print("SPLADE-cocondenser-distil model not loaded for indexing.")
+            return False
+        tokenizer_to_use = tokenizer_splade
+        model_to_use = model_splade
+        representation_func_to_use = get_splade_cocondenser_representation_internal
+    elif doc_model_choice == "SPLADE-v3-Lexical":
+        if tokenizer_splade_lexical is None or model_splade_lexical is None:
+            print("SPLADE-v3-Lexical model not loaded for indexing.")
+            return False
+        tokenizer_to_use = tokenizer_splade_lexical
+        model_to_use = model_splade_lexical
+        representation_func_to_use = get_splade_lexical_representation_internal
+    elif doc_model_choice == "SPLADE-v3-Doc":
+        if tokenizer_splade_doc is None or model_splade_doc is None:
+            print("SPLADE-v3-Doc model not loaded for indexing.")
+            return False
+        tokenizer_to_use = tokenizer_splade_doc
+        model_to_use = model_splade_doc
+        representation_func_to_use = get_splade_doc_representation_internal
+    else:
+        print(f"Invalid model choice for document indexing: {doc_model_choice}")
+        return False
+
+    print(f"Indexing documents using {doc_model_choice}...")
+
+    doc_items = list(document_texts.items())
+
+    for doc_id, doc_text in tqdm(doc_items, desc="Indexing Documents"):
+        sparse_vector = representation_func_to_use(doc_text, tokenizer_to_use, model_to_use)
+        if sparse_vector is not None:
+            document_representations[doc_id] = sparse_vector.cpu()
+        else:
+            print(f"Warning: Failed to get representation for doc_id {doc_id}")
+
+    print(f"Finished indexing {len(document_representations)} documents.")
+    return True
+
+# --- Retrieval Function (for Retrieval Tab) ---
+def retrieve_documents(query_text, query_model_choice, indexed_doc_model_name, top_k=5):
+    if not document_representations:
+        return "Document index is not loaded or empty. Please ensure documents are indexed.", []
+
+    query_vector = None
+    query_tokenizer = None
+    query_model = None
+
+    if query_model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+        query_tokenizer = tokenizer_splade
+        query_model = model_splade
+        query_vector = get_splade_cocondenser_representation_internal(query_text, query_tokenizer, query_model)
+    elif query_model_choice == "SPLADE-v3-Lexical (weighting)":
+        query_tokenizer = tokenizer_splade_lexical
+        query_model = model_splade_lexical
+        query_vector = get_splade_lexical_representation_internal(query_text, query_tokenizer, query_model)
+    elif query_model_choice == "SPLADE-v3-Doc (binary)":
+        query_tokenizer = tokenizer_splade_doc
+        query_model = model_splade_doc
+        query_vector = get_splade_doc_representation_internal(query_text, query_tokenizer, query_model)
+    else:
+        return "Invalid query model choice.", []
+
+    if query_vector is None:
+        return "Failed to get query representation. Check console for model loading errors.", []
+
+    query_vector = query_vector.cpu()
+
+    scores = {}
+    for doc_id, doc_vec in document_representations.items():
+        score = torch.dot(query_vector, doc_vec).item()
+        scores[doc_id] = score
+
+    sorted_scores = sorted(scores.items(), key=lambda item: item[1], reverse=True)
+    top_results = sorted_scores[:top_k]
+
+    formatted_output = f"Retrieval Results for Query: '{query_text}'\n"
+    formatted_output += f"Using Query Model: **{query_model_choice}**\n"
+    formatted_output += f"Documents Indexed with: **{indexed_doc_model_name}**\n\n"
+
+    if not top_results:
+        formatted_output += "No documents found or scored.\n"
+    else:
+        for i, (doc_id, score) in enumerate(top_results):
+            doc_text = document_texts.get(doc_id, "Document text not available.")
+            formatted_output += f"**{i+1}. Document ID: {doc_id}** (Score: {score:.4f})\n"
+            formatted_output += f"> {doc_text[:300]}...\n\n"
+
+    return formatted_output, top_results
+
+# --- Unified Prediction Function for Gradio (for Retrieval Tab) ---
+def predict_retrieval_gradio(query_text, query_model_choice, selected_doc_model_display_only):
+    formatted_output, _ = retrieve_documents(query_text, query_model_choice, initial_doc_model_for_indexing, top_k=5)
+    return formatted_output
+
+# --- Initial Load and Indexing Calls ---
+# This part runs once when the app starts.
+load_scifact_corpus_ir_datasets() # Or load_cranfield_corpus_ir_datasets() if you switch back
+
+if initial_doc_model_for_indexing == "SPLADE-cocondenser-distil" and model_splade is not None:
+    index_documents(initial_doc_model_for_indexing)
+elif initial_doc_model_for_indexing == "SPLADE-v3-Lexical" and model_splade_lexical is not None:
+    index_documents(initial_doc_model_for_indexing)
+elif initial_doc_model_for_indexing == "SPLADE-v3-Doc" and model_splade_doc is not None:
+    index_documents(initial_doc_model_for_indexing)
+else:
+    print(f"Skipping document indexing: Model '{initial_doc_model_for_indexing}' failed to load or is not a valid choice for indexing.")
+
+
+# --- Gradio Interface Setup with Tabs ---
+with gr.Blocks(title="SPLADE Demos") as demo:
+    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer & Document Retrieval")
+    gr.Markdown("Explore different SPLADE models and their sparse representation types, or perform document retrieval on a test collection.")
+
+    with gr.Tabs():
+        with gr.TabItem("Sparse Representation Explorer"):
+            gr.Markdown("### Explore Raw SPLADE Representations for Any Text")
+            gr.Interface(
+                fn=predict_representation_explorer,
+                inputs=[
+                    gr.Radio(
+                        [
+                            "SPLADE-cocondenser-distil (weighting and expansion)",
+                            "SPLADE-v3-Lexical (weighting)",
+                            "SPLADE-v3-Doc (binary)"
+                        ],
+                        label="Choose Representation Model",
+                        value="SPLADE-cocondenser-distil (weighting and expansion)"
+                    ),
+                    gr.Textbox(
+                        lines=5,
+                        label="Enter your query or document text here:",
+                        placeholder="e.g., Why is Padua the nicest city in Italy?"
+                    )
+                ],
+                outputs=gr.Markdown(),
+                allow_flagging="never",
+                # Don't show redundant title/description within the tab, as it's above
+                # Setting live=True might be slow for complex models on every keystroke
+                # live=True
+            )
+
+        with gr.TabItem("Document Retrieval Demo"):
+            gr.Markdown("### Retrieve Documents from SciFact Collection")
+            gr.Interface(
+                fn=predict_retrieval_gradio,
+                inputs=[
+                    gr.Textbox(
+                        lines=3,
+                        label="Enter your query text here:",
+                        placeholder="e.g., Does high-dose vitamin C cure cancer?"
+                    ),
+                    gr.Radio(
+                        [
+                            "SPLADE-cocondenser-distil (weighting and expansion)",
+                            "SPLADE-v3-Lexical (weighting)",
+                            "SPLADE-v3-Doc (binary)"
+                        ],
+                        label="Choose Query Representation Model",
+                        value="SPLADE-cocondenser-distil (weighting and expansion)"
+                    ),
+                    gr.Radio(
+                        [
+                            "SPLADE-cocondenser-distil",
+                            "SPLADE-v3-Lexical",
+                            "SPLADE-v3-Doc"
+                        ],
+                        label=f"Document Index Model (Pre-indexed with: {initial_doc_model_for_indexing})",
+                        value=initial_doc_model_for_indexing,
+                        interactive=False # This radio is fixed for simplicity
+                    )
+                ],
+                outputs=gr.Markdown(),
+                allow_flagging="never",
+                # live=True # retrieval is too heavy for live
+            )

 demo.launch()
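
Notes on the change

All three representation paths share the same pooling step. Restated as a formula, what `torch.max(torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1), dim=1)[0]` computes for vocabulary term j is (with l_ij the MLM logit of term j at input position i, and m_i in {0, 1} the attention mask):

    w_j = \max_i \; m_i \cdot \log\bigl(1 + \operatorname{ReLU}(\ell_{ij})\bigr)

The lexical and doc variants then zero out every w_j whose term does not literally occur in the input, via create_lexical_bow_mask.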
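To see exactly which terms that mask keeps, a quick check can be run in the same Python session as app.py (a hypothetical snippet, assuming the SPLADE-v3-Lexical model loaded):

    # Tokenize a sample text and build its bag-of-words mask over the vocabulary.
    inputs = tokenizer_splade_lexical("padua is the nicest city in italy", return_tensors="pt")
    mask = create_lexical_bow_mask(
        inputs["input_ids"], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    )
    # List the surviving tokens: non-special, deduplicated input terms only.
    kept_ids = torch.nonzero(mask.squeeze()).squeeze(-1).tolist()
    print(tokenizer_splade_lexical.convert_ids_to_tokens(kept_ids))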
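retrieve_documents scores the corpus with a Python-level loop of torch.dot calls. Because each indexed vector is a dense vocab-sized tensor, the same ranking can be produced with one matrix-vector product; a minimal sketch under that assumption (score_all_documents is a hypothetical helper, not part of app.py):

    import torch

    def score_all_documents(query_vector, document_representations):
        # Stack the indexed (vocab_size,) tensors into one (N, vocab_size) matrix,
        # then score every document with a single matmul instead of N dot products.
        doc_ids = list(document_representations.keys())
        doc_matrix = torch.stack([document_representations[d] for d in doc_ids])
        scores = doc_matrix @ query_vector  # shape: (N,)
        return sorted(zip(doc_ids, scores.tolist()), key=lambda p: p[1], reverse=True)

The dictionary-of-dense-tensors index also costs vocab_size floats per document; for corpora larger than SciFact, storing only the nonzero (index, weight) pairs would be the usual next step.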
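Finally, the new imports mean the Space's dependency list has to grow accordingly; assuming a standard requirements.txt (not shown in this commit), something like:

    gradio
    transformers
    torch
    numpy
    tqdm
    ir_datasets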