cognisafe-backend / app /ml_models /classifier_loader.py
zyriean's picture
add app
d68e65a verified
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from app.ml_models.classifier_path_loader import ClassifierPathLoader
import logging
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, pipeline
logger = logging.getLogger(__name__)


class ClassifierLoader:
    """Lazily load a sequence-classification model and its tokenizer.

    On construction, the local directory for ``model_name`` is resolved via
    ``ClassifierPathLoader``. If that directory does not exist, the model and
    tokenizer identified by ``source_model_id`` are fetched with
    ``from_pretrained`` and written there with ``save_pretrained``.
    ``load_model`` / ``load_tokenizer`` then read from that directory once
    and cache the result on the instance.
    """

    # Repository passed to ``from_pretrained`` when no local copy exists.
    # This was previously hard-coded inside ``__init__``; it is kept as the
    # default so existing callers download the same model as before.
    DEFAULT_SOURCE_MODEL_ID = "unitary/toxic-bert"

    def __init__(
        self,
        model_name: str,
        source_model_id: str = DEFAULT_SOURCE_MODEL_ID,
    ):
        """Resolve the local model path, downloading the model if it is missing.

        Args:
            model_name: Name handed to ``ClassifierPathLoader.set_model`` to
                pick the local directory.
            source_model_id: Identifier passed to ``from_pretrained`` when
                the local directory does not exist. Defaults to
                ``DEFAULT_SOURCE_MODEL_ID``.

        NOTE(review): ``model_name`` only selects the local path; it is not
        used as the download source. If ``model_name`` refers to a model
        other than the default, pass a matching ``source_model_id``.
        Otherwise the default model is saved under that path.
        """
        self.model_name = model_name
        self.source_model_id = source_model_id
        # Populated lazily by load_model / load_tokenizer.
        self.model = None
        self.tokenizer = None

        path_loader = ClassifierPathLoader()
        path_loader.set_model(self.model_name)
        self.model_path = path_loader.get_model_path()

        self._ensure_local_copy()

    def _ensure_local_copy(self) -> None:
        """Fetch ``source_model_id`` and save it to ``model_path`` if absent."""
        if self.model_path.exists():
            return

        logger.info(
            "No local model at %s; downloading %s.",
            self.model_path,
            self.source_model_id,
        )
        # Fetch both artifacts before writing anything, so a failed fetch
        # does not leave a directory that the exists() check above would
        # later treat as complete.
        tokenizer = AutoTokenizer.from_pretrained(self.source_model_id)
        model = AutoModelForSequenceClassification.from_pretrained(
            self.source_model_id
        )
        tokenizer.save_pretrained(self.model_path)
        model.save_pretrained(self.model_path)

    def load_model(self):
        """Return the classification model, loading it from ``model_path`` once.

        Later calls return the cached instance without touching disk.
        """
        if self.model is None:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_path
            )
            logger.info("[βœ…] Model loaded successfully.")
        return self.model

    def load_tokenizer(self):
        """Return the tokenizer, loading it from ``model_path`` once.

        Later calls return the cached instance without touching disk.
        """
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            logger.info("[βœ…] Tokenizer loaded successfully.")
        return self.tokenizer