ajitrajasekharan commited on
Commit
0d25a6d
·
1 Parent(s): 7c8e7ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -54,8 +54,8 @@ def get_all_predictions(text_sentence, model_name,top_clean=5):
54
 
55
  with torch.no_grad():
56
  predict = bert_model(input_ids)[0]
57
- bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*10).indices.tolist(), top_clean)
58
- cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*10).indices.tolist(), top_clean)
59
 
60
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
61
  return {'Input sentence':text_sentence,'Tokenized text': tokenized_text, 'results_count':top_k,'Model':model_name,'Masked position': bert,'[CLS]':cls}
 
54
 
55
  with torch.no_grad():
56
  predict = bert_model(input_ids)[0]
57
+ bert = decode(bert_tokenizer, predict[0, mask_idx, :].topk(top_k*2).indices.tolist(), top_clean)
58
+ cls = decode(bert_tokenizer, predict[0, 0, :].topk(top_k*2).indices.tolist(), top_clean)
59
 
60
  if ("[MASK]" in text_sentence or "<mask>" in text_sentence):
61
  return {'Input sentence':text_sentence,'Tokenized text': tokenized_text, 'results_count':top_k,'Model':model_name,'Masked position': bert,'[CLS]':cls}