File size: 452 Bytes
9decc9d
 
 
 
 
 
 
0005ea7
9decc9d
 
0005ea7
1
2
3
4
5
6
7
8
9
10
11
from transformers import pipeline
from models.base import BaseModel

class ZeroShotModel(BaseModel):
    """Zero-shot text classifier built on a Hugging Face NLI pipeline.

    The model and the candidate labels are configurable; the defaults
    reproduce the original behavior (``facebook/bart-large-mnli`` scoring the
    French labels "toxique" / "non-toxique").
    """

    # Hugging Face model id loaded when no other model is requested.
    DEFAULT_MODEL = "facebook/bart-large-mnli"
    # Default candidate labels (French for "toxic" / "non-toxic").
    DEFAULT_LABELS = ("toxique", "non-toxique")

    def __init__(
        self,
        model: str = DEFAULT_MODEL,
        candidate_labels: tuple[str, ...] = DEFAULT_LABELS,
    ):
        """Load the zero-shot pipeline and fix the labels to score.

        Args:
            model: Hugging Face model identifier passed to ``pipeline``.
            candidate_labels: Labels scored by :meth:`predict`. Any iterable
                of strings is accepted; it is copied into a tuple so later
                mutation by the caller has no effect on this instance.

        Raises:
            ValueError: If ``candidate_labels`` is empty.
        """
        labels = tuple(candidate_labels)
        # Validate before building the pipeline so a bad configuration fails
        # fast instead of after an expensive model load.
        if not labels:
            raise ValueError("candidate_labels must contain at least one label")
        self.candidate_labels = labels
        # Building the pipeline loads the model, so do it once per instance
        # rather than on every predict() call.
        self.classifier = pipeline("zero-shot-classification", model=model)

    def predict(self, text: str) -> list[tuple[str, float]]:
        """Score ``text`` against the configured candidate labels.

        Args:
            text: Input text to classify.

        Returns:
            ``(label, score)`` pairs, in the order the pipeline returns them
            (the Hugging Face docs describe this as most likely first).
        """
        # The pipeline is called with a list, matching the original call.
        result = self.classifier(
            text, candidate_labels=list(self.candidate_labels)
        )
        return list(zip(result["labels"], result["scores"]))