Tbruand
commited on
Commit
·
9decc9d
1
Parent(s):
c592603
feat(models): ajoute un modèle de classification zero-shot basé sur bart-large-mnli
Browse files- app/handler.py +3 -10
- models/zero_shot.py +11 -0
app/handler.py
CHANGED
@@ -1,13 +1,6 @@
|
|
1 |
-
from models.
|
2 |
|
3 |
-
|
4 |
-
class DummyModel(BaseModel):
|
5 |
-
def predict(self, text):
|
6 |
-
if "stupide" in text.lower():
|
7 |
-
return "toxique"
|
8 |
-
return "non-toxique"
|
9 |
-
|
10 |
-
dummy = DummyModel()
|
11 |
|
12 |
def predict(text: str) -> str:
|
13 |
-
return
|
|
|
1 |
+
from models.zero_shot import ZeroShotModel
|
2 |
|
3 |
+
model = ZeroShotModel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
def predict(text: str) -> str:
|
6 |
+
return model.predict(text)
|
models/zero_shot.py
CHANGED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import pipeline
|
2 |
+
from models.base import BaseModel
|
3 |
+
|
4 |
+
class ZeroShotModel(BaseModel):
|
5 |
+
def __init__(self):
|
6 |
+
self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
7 |
+
|
8 |
+
def predict(self, text: str) -> str:
|
9 |
+
labels = ["toxique", "non-toxique"]
|
10 |
+
result = self.classifier(text, candidate_labels=labels)
|
11 |
+
return result["labels"][0]
|