A815 commited on
Commit
e3355b1
·
1 Parent(s): 84890a4
Files changed (2) hide show
  1. app.py +62 -0
  2. nlp4web-codebase +0 -1
app.py CHANGED
@@ -332,6 +332,68 @@ bm25_index = BM25Index.build_from_documents(
332
  bm25_index.save("output/bm25_index")
333
 
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
  from scipy.sparse._csc import csc_matrix
337
 
 
332
  bm25_index.save("output/bm25_index")
333
 
334
 
335
+ plots_b: Dict[str, List[float]] = {
336
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
337
+ "Y": []
338
+ }
339
+ plots_k1: Dict[str, List[float]] = {
340
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
341
+ "Y": []
342
+ }
343
+
344
+ ## YOUR_CODE_STARTS_HERE
345
+ # Two steps should be involved:
346
+ # Step 1. Fix k1 value to the default one 0.9,
347
+ # go through all the candidate b values (0, 0.1, ..., 1.0),
348
+ # and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
349
+ # Step 2. Fix b to the best one in step 1. and do the same for k1.
350
+
351
+ # Hint (on using the pre-requisite code):
352
+ # - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
353
+ # - One can build bm25_index with `BM25Index.build_from_documents`;
354
+ # - One can use BM25Retriever to load the index and perform retrieval on the dev queries
355
+ # (dev queries can be obtained via sciq.get_split_queries(Split.dev))
356
+
357
+
358
+ for b in plots_b["X"]:
359
+ bm25_index = BM25Index.build_from_documents(
360
+ documents=iter(sciq.corpus),
361
+ ndocs=12160,
362
+ show_progress_bar=False,
363
+ k1=0.9,
364
+ b=b
365
+ )
366
+ bm25_index.save("output/bm25_index")
367
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
368
+ rankings = {}
369
+ for query in sciq.get_split_queries(Split.dev):
370
+ ranking = bm25_retriever.retrieve(query=query.text)
371
+ rankings[query.query_id] = ranking
372
+
373
+ k1_b_map = evaluate_map(rankings, split=Split.dev)
374
+ plots_b["Y"].append(k1_b_map)
375
+
376
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]
377
+
378
+ for k1 in plots_k1["X"]:
379
+ bm25_index = BM25Index.build_from_documents(
380
+ documents=iter(sciq.corpus),
381
+ ndocs=12160,
382
+ show_progress_bar=False,
383
+ k1=k1,
384
+ b=best_b
385
+ )
386
+ bm25_index.save("output/bm25_index")
387
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
388
+ rankings = {}
389
+ for query in sciq.get_split_queries(Split.dev):
390
+ ranking = bm25_retriever.retrieve(query=query.text)
391
+ rankings[query.query_id] = ranking
392
+
393
+ k1_b_map = evaluate_map(rankings, split=Split.dev)
394
+ plots_k1["Y"].append(k1_b_map)
395
+
396
+
397
 
398
  from scipy.sparse._csc import csc_matrix
399
 
nlp4web-codebase DELETED
@@ -1 +0,0 @@
1
- Subproject commit 83f9afbbf7e372c116fdd04997a96449007f861f