Tbruand commited on
Commit
9decc9d
·
1 Parent(s): c592603

feat(models): ajoute un modèle de classification zero-shot basé sur bart-large-mnli

Browse files
Files changed (2) hide show
  1. app/handler.py +3 -10
  2. models/zero_shot.py +11 -0
app/handler.py CHANGED
@@ -1,13 +1,6 @@
1
- from models.base import BaseModel
2
 
3
- # Exemple de modèle minimal
4
- class DummyModel(BaseModel):
5
- def predict(self, text):
6
- if "stupide" in text.lower():
7
- return "toxique"
8
- return "non-toxique"
9
-
10
- dummy = DummyModel()
11
 
12
  def predict(text: str) -> str:
13
- return dummy.predict(text)
 
1
+ from models.zero_shot import ZeroShotModel
2
 
3
+ model = ZeroShotModel()
 
 
 
 
 
 
 
4
 
5
  def predict(text: str) -> str:
6
+ return model.predict(text)
models/zero_shot.py CHANGED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from models.base import BaseModel
3
+
4
+ class ZeroShotModel(BaseModel):
5
+ def __init__(self):
6
+ self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
7
+
8
+ def predict(self, text: str) -> str:
9
+ labels = ["toxique", "non-toxique"]
10
+ result = self.classifier(text, candidate_labels=labels)
11
+ return result["labels"][0]