File size: 597 Bytes
39dcbb5
21b5c94
39dcbb5
 
 
21b5c94
 
39dcbb5
402868e
21b5c94
 
402868e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from models.base import BaseModel
from transformers import pipeline

class FewShotModel(BaseModel):
    """Binary "toxique" / "non-toxique" classifier backed by a HF pipeline.

    Wraps a Hugging Face ``text-classification`` pipeline and maps its raw
    output label onto the two labels used by this project.

    NOTE(review): the default checkpoint is a sentiment model fine-tuned on
    Rotten Tomatoes movie reviews, so "toxique" effectively means "negative
    sentiment" here — confirm this proxy is intended.
    """

    # Default checkpoint. Its config exposes generic labels LABEL_0 / LABEL_1,
    # where LABEL_1 is the positive class of the Rotten Tomatoes dataset.
    DEFAULT_MODEL = "textattack/roberta-base-rotten-tomatoes"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load a pretrained text-classification pipeline.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``transformers.pipeline``. Defaults to ``DEFAULT_MODEL``.
        """
        # Use a pretrained model for text classification.
        self.classifier = pipeline("text-classification", model=model_name)

    @classmethod
    def _to_toxicity_label(cls, raw_label: str) -> str:
        """Map a raw pipeline label to "non-toxique" or "toxique".

        Labels containing "pos" (e.g. "POSITIVE") or equal to "LABEL_1"
        (the positive class for generic-labelled binary checkpoints such as
        ``DEFAULT_MODEL``) are treated as non-toxic. Matching is
        case-insensitive; everything else is toxic.
        """
        label = raw_label.lower()
        # Generic checkpoints never contain "pos" in their labels, so the
        # numeric LABEL_1 form must be recognized explicitly.
        if label == "label_1" or "pos" in label:
            return "non-toxique"
        return "toxique"

    def predict(self, text: str) -> list[tuple[str, float]]:
        """Classify ``text`` and return a single (label, score) pair.

        Args:
            text: Input text; truncated to the model's max length.

        Returns:
            A one-element list ``[(label, score)]``. ``label`` is
            "non-toxique" or "toxique". ``score`` is the pipeline's
            confidence for the predicted raw class.
        """
        # A single string input yields a list with one result dict.
        result = self.classifier(text, truncation=True)[0]
        label = self._to_toxicity_label(result["label"])
        return [(label, result["score"])]