Akeb0n0 committed on
Commit
431d797
·
verified ·
1 Parent(s): 2b4bdc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -4,13 +4,19 @@ import torch.serialization
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
  @st.cache_resource
7
- def load_model():
8
- checkpoint = torch.load('TinyBERT_cls_model.pt',
 
 
 
 
 
 
9
  map_location='cpu',
10
  weights_only=False)
11
 
12
  model = AutoModelForSequenceClassification.from_pretrained(
13
- "huawei-noah/TinyBERT_General_4L_312D",
14
  num_labels=len(checkpoint['idx_to_category'])
15
  )
16
  model.load_state_dict(checkpoint['model_state_dict'])
@@ -20,7 +26,7 @@ def load_model():
20
 
21
  return model, tokenizer, idx_to_category
22
 
23
- def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.95):
24
  text = f"{title} /n {abstract}" if abstract else title
25
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
26
 
 
4
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
  @st.cache_resource
7
+ def load_model():
8
+ #trained_model = 'TinyBERT_cls_model.pt'
9
+ #base_model = 'huawei-noah/TinyBERT_General_4L_312D'
10
+
11
+ trained_model = 'distilbert-base_cls_model.pt'
12
+ base_model = 'distilbert-base-uncased'
13
+
14
+ checkpoint = torch.load(trained_model,
15
  map_location='cpu',
16
  weights_only=False)
17
 
18
  model = AutoModelForSequenceClassification.from_pretrained(
19
+ base_model,
20
  num_labels=len(checkpoint['idx_to_category'])
21
  )
22
  model.load_state_dict(checkpoint['model_state_dict'])
 
26
 
27
  return model, tokenizer, idx_to_category
28
 
29
+ def predict(title, abstract, model, tokenizer, idx_to_category, threshold=0.97):
30
  text = f"{title} /n {abstract}" if abstract else title
31
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
32