import torch
from transformers import AutoModelForSequenceClassification, BertTokenizerFast


class BertWrapper:
    MODELS_DIR: str = "new_models/"
    MODEL_NAME: str = "model"
    TOKENIZER: str = "tokenizer"

    def __init__(self) -> None:
        self.model = AutoModelForSequenceClassification.from_pretrained(
            BertWrapper.MODELS_DIR + BertWrapper.MODEL_NAME, torchscript=True
        )
        self.tokenizer = BertTokenizerFast.from_pretrained(
            "blanchefort/rubert-base-cased-sentiment"
        )
        # fastText-style label strings for the two sentiment classes
        self.id2label: dict[int, str] = {0: "__label__positive", 1: "__label__negative"}

    def __call__(self, text: str) -> tuple[str, float]:
        max_input_length = (
            self.model.config.max_position_embeddings
        )  # 512 for this model
        inputs = self.tokenizer(
            text,
            max_length=max_input_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():  # inference only; skip gradient tracking
            outputs = self.model(
                **inputs, return_dict=True
            )  # Hugging Face models return a ModelOutput; .logits holds the raw class scores
        predicted = torch.nn.functional.softmax(outputs.logits, dim=1)
        predicted_id = torch.argmax(predicted, dim=1).item()
        return self.id2label[predicted_id], predicted[0][predicted_id].item()
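

# Example usage: a minimal sketch, assuming the fine-tuned weights exist under
# new_models/model/ (the path BertWrapper loads from) and that the model
# expects Russian text, matching the rubert tokenizer loaded above.
if __name__ == "__main__":
    wrapper = BertWrapper()
    label, confidence = wrapper("Отличный сервис, всем рекомендую!")  # "Great service, I recommend it to everyone!"
    print(label, confidence)  # the predicted label and its softmax probability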