File size: 597 Bytes
39dcbb5 21b5c94 39dcbb5 21b5c94 39dcbb5 402868e 21b5c94 402868e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from models.base import BaseModel
from transformers import pipeline
class FewShotModel(BaseModel):
    """Binary "toxique" / "non-toxique" text classifier backed by a Hugging Face pipeline.

    NOTE(review): despite the class name, no few-shot learning happens here. The
    default checkpoint appears to be a movie-review sentiment model (Rotten
    Tomatoes) used as a proxy: positive sentiment is reported as "non-toxique"
    and everything else as "toxique". Confirm this proxy is intended.
    """

    # Hugging Face model id loaded when no other model is requested.
    DEFAULT_MODEL = "textattack/roberta-base-rotten-tomatoes"

    # Lower-cased generic label ids treated as the positive class. TextAttack
    # checkpoints may expose "LABEL_0"/"LABEL_1" instead of "NEGATIVE"/"POSITIVE";
    # in the Rotten Tomatoes dataset label 1 is the positive class.
    # TODO(review): adjust if a different model_name uses another label scheme.
    _POSITIVE_IDS = frozenset({"label_1"})

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load a pretrained text-classification pipeline.

        Args:
            model_name: Hugging Face model id; defaults to ``DEFAULT_MODEL``.
        """
        self.classifier = pipeline("text-classification", model=model_name)

    @classmethod
    def _to_toxicity_label(cls, raw_label: str) -> str:
        """Map a raw pipeline label to "non-toxique" (positive) or "toxique" (anything else)."""
        normalized = raw_label.lower()
        # Accept both named labels ("POSITIVE") and generic ids ("LABEL_1");
        # matching only "pos" would classify every input as toxic with generic ids.
        if "pos" in normalized or normalized in cls._POSITIVE_IDS:
            return "non-toxique"
        return "toxique"

    def predict(self, text: str) -> list[tuple[str, float]]:
        """Classify *text* and return a single ``(label, score)`` pair.

        Args:
            text: Input text; passed with ``truncation=True`` so over-long
                inputs are truncated rather than rejected.

        Returns:
            A one-element list ``[(label, score)]`` where ``label`` is
            "toxique" or "non-toxique" and ``score`` is the pipeline's score
            for the predicted class.
        """
        # A single string input yields a list; take its first (only) prediction.
        result = self.classifier(text, truncation=True)[0]
        label = self._to_toxicity_label(result["label"])
        return [(label, float(result["score"]))]