File size: 810 Bytes
9decc9d
aac811b
3ece550
aac811b
 
4d029c2
 
 
 
3ece550
aac811b
 
880f334
88ca48d
 
4d029c2
88ca48d
aac811b
 
88ca48d
 
 
 
4d029c2
88ca48d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from models.zero_shot import ZeroShotModel
from models.few_shot import FewShotModel

# Classifier instances created once, when this module is imported.
# predict() uses few_shot_model for model_type == "few-shot" and
# zero_shot_model as the fallback for any model_type other than
# "few-shot" or "fine-tuned".
zero_shot_model = ZeroShotModel()
few_shot_model = FewShotModel()

# Lazily-created FineTunedModel shared by all callers; None until first use.
_fine_tuned_model = None

def get_fine_tuned_model():
    """Return the shared FineTunedModel, creating it on first use.

    The import of ``models.fine_tuned`` stays inside the function so it is
    only performed when the fine-tuned model is actually requested. The
    instance is built once and cached at module level, so repeated calls
    return the same object instead of constructing a new model every time.

    Returns:
        The cached ``FineTunedModel`` instance.
    """
    global _fine_tuned_model
    if _fine_tuned_model is None:
        # Deferred import: nothing from models.fine_tuned is touched
        # until a caller asks for the fine-tuned model.
        from models.fine_tuned import FineTunedModel

        _fine_tuned_model = FineTunedModel()
    return _fine_tuned_model

def predict(text: str, model_type: str = "zero-shot") -> str:
    """Classify *text* with the selected model and format the scores as Markdown.

    Args:
        text: The input passed to the chosen model's ``predict`` method.
        model_type: ``"few-shot"`` selects ``few_shot_model``, ``"fine-tuned"``
            selects the model returned by ``get_fine_tuned_model()``; any
            other value falls back to ``zero_shot_model``.

    Returns:
        A Markdown string: a heading naming the model variant, followed by
        one bullet per ``(label, score)`` pair yielded by the model, with the
        score multiplied by 100 and shown to one decimal place.
    """
    # Pick the model and the heading label together.
    if model_type == "few-shot":
        classifier, heading = few_shot_model, "Few-Shot"
    elif model_type == "fine-tuned":
        classifier, heading = get_fine_tuned_model(), "Fine-Tuned"
    else:
        classifier, heading = zero_shot_model, "Zero-Shot"

    # Collect the heading and bullet lines, then join them in one pass.
    pieces = [f"### Résultat de la classification ({heading}) :\n\n"]
    pieces.extend(
        f"- **{label}** : {score * 100:.1f}%\n"
        for label, score in classifier.predict(text)
    )
    return "".join(pieces)