Tbruand
commited on
Commit
·
21b5c94
1
Parent(s):
4af0704
feat(models): ajout d'un modèle FewShot basé sur transformers avec Roberta
Browse files- models/few_shot.py +7 -13
models/few_shot.py
CHANGED
@@ -1,19 +1,13 @@
|
|
1 |
from models.base import BaseModel
|
|
|
2 |
|
3 |
class FewShotModel(BaseModel):
|
4 |
def __init__(self):
|
5 |
-
|
6 |
-
|
7 |
-
("Je vais te tuer", "toxique"),
|
8 |
-
("Merci pour ton aide", "non-toxique"),
|
9 |
-
("J'apprécie ton soutien", "non-toxique")
|
10 |
-
]
|
11 |
|
12 |
def predict(self, text: str) -> str:
|
13 |
-
|
14 |
-
|
15 |
-
#
|
16 |
-
|
17 |
-
non_toxic_score = sum(any(word in example.lower() for word in text.split()) for example, label in self.examples if label == "non-toxique")
|
18 |
-
|
19 |
-
return "toxique" if toxic_score >= non_toxic_score else "non-toxique"
|
|
|
1 |
from models.base import BaseModel
|
2 |
+
from transformers import pipeline
|
3 |
|
4 |
class FewShotModel(BaseModel):
|
5 |
def __init__(self):
|
6 |
+
# On utilise un modèle préentraîné pour la classification de texte
|
7 |
+
self.classifier = pipeline("text-classification", model="textattack/roberta-base-rotten-tomatoes")
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def predict(self, text: str) -> str:
|
10 |
+
result = self.classifier(text, truncation=True)[0]
|
11 |
+
label = result["label"].lower()
|
12 |
+
# Conversion binaire "positive"/"negative" en "non-toxique"/"toxique"
|
13 |
+
return "non-toxique" if "pos" in label else "toxique"
|
|
|
|
|
|