File size: 810 Bytes
9decc9d aac811b 3ece550 aac811b 4d029c2 3ece550 aac811b 880f334 88ca48d 4d029c2 88ca48d aac811b 88ca48d 4d029c2 88ca48d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import functools

from models.zero_shot import ZeroShotModel
from models.few_shot import FewShotModel
# Shared classifier instances, constructed once when this module is imported
# (zero-shot first, then few-shot) and reused by every predict() call that
# selects the matching model type.
zero_shot_model, few_shot_model = ZeroShotModel(), FewShotModel()
@functools.lru_cache(maxsize=None)
def get_fine_tuned_model():
    """Return the shared fine-tuned classifier, building it on first use.

    The import of ``models.fine_tuned`` is deferred to the first call, so
    importing this module does not load it. The result is memoized, so later
    calls return the same instance instead of constructing a new
    ``FineTunedModel`` on every request. This matches how the zero-shot and
    few-shot models are created once at module level.

    If construction raises, nothing is cached and the next call retries.
    """
    # Local import keeps the fine-tuned dependency out of module import time.
    # NOTE(review): construction is presumably expensive (e.g. loading
    # weights) — the reason for caching; confirm against FineTunedModel.
    from models.fine_tuned import FineTunedModel

    return FineTunedModel()
def predict(text: str, model_type: str = "zero-shot") -> str:
    """Classify *text* and render the scores as a Markdown report.

    Args:
        text: Input passed unchanged to the selected model's ``predict``.
        model_type: ``"few-shot"`` or ``"fine-tuned"``. Any other value,
            including the default ``"zero-shot"``, uses the zero-shot model.

    Returns:
        A Markdown string: a level-3 heading naming the model used, followed
        by one bullet per ``(label, score)`` pair. Each score is shown
        multiplied by 100 with one decimal place and a ``%`` sign.
    """
    # Choose the model and its display name. Unknown types fall through to
    # zero-shot rather than raising.
    if model_type == "few-shot":
        model, title = few_shot_model, "Few-Shot"
    elif model_type == "fine-tuned":
        # Only this branch calls get_fine_tuned_model().
        model, title = get_fine_tuned_model(), "Fine-Tuned"
    else:
        model, title = zero_shot_model, "Zero-Shot"

    results = model.predict(text)

    # NOTE(review): the *100 suggests the model returns scores as fractions
    # in [0, 1] — confirm against the model classes.
    header = f"### Résultat de la classification ({title}) :\n\n"
    bullets = [f"- **{label}** : {score * 100:.1f}%\n" for label, score in results]
    return header + "".join(bullets)