Tbruand commited on
Commit
0005ea7
·
1 Parent(s): 52e008b

feat(models): retourne tous les labels et scores dans ZeroShotModel.predict

Browse files
Files changed (1) hide show
  1. models/zero_shot.py +2 -4
models/zero_shot.py CHANGED
@@ -5,9 +5,7 @@ class ZeroShotModel(BaseModel):
5
  def __init__(self):
6
  self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
7
 
8
- def predict(self, text: str) -> tuple[str, float]:
9
  labels = ["toxique", "non-toxique"]
10
  result = self.classifier(text, candidate_labels=labels)
11
- label = result["labels"][0]
12
- score = result["scores"][0]
13
- return label, score
 
5
  def __init__(self):
6
  self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
7
 
8
+ def predict(self, text: str) -> list[tuple[str, float]]:
9
  labels = ["toxique", "non-toxique"]
10
  result = self.classifier(text, candidate_labels=labels)
11
+ return list(zip(result["labels"], result["scores"]))