import os
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# quantization helpers from optimum-quanto (pip install optimum-quanto)
from optimum.quanto import quantize, freeze, qint4

# === [AUTHENTICATION] ===
hf_token = os.getenv("HF_TOKEN")
if hf_token is None:
    raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face token")
login(token=hf_token)

# === [TRANSLATOR] ===
translator = pipeline("translation", model="facebook/nllb-200-distilled-600M")
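# NLLB models use FLORES-200 language codes; the source/target pair
# (fra_Latn, eng_Latn) is passed per call in medical_chatbot() below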

# === [LOAD & QUANTIZE MODEL] ===
model_name = "ContactDoctor/Bio-Medical-Llama-3-2-1B-CoT-012025"
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

print("Quantizing model...")
quantized_model = quantize_model(model, bits=4)

print("Initializing pipeline...")
text_gen_pipeline = pipeline(
    "text-generation",
    model=quantized_model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto"
)

# === [SYSTEM MESSAGE] ===
system_message = {
    "role": "system",
    "content": (
        "You are a helpful, respectful, and knowledgeable medical assistant developed by the AI team at AfriAI Solutions, Senegal. "
        "Provide brief, clear definitions when answering medical questions. After giving a concise response, ask the user if they would like more information about symptoms, causes, or treatments. "
        "Always encourage users to consult healthcare professionals for personalized advice."
    )
}
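
# The system prompt stays in English because user turns are translated to
# English before they reach the model (see medical_chatbot below)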

messages = [system_message]
max_history = 10
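# max_history counts user/assistant exchanges, so trimming keeps the system
# message plus the last max_history * 2 chat messages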

salutations = ["bonjour", "salut", "bonsoir", "coucou"]
remerciements = ["merci", "je vous remercie", "thanks"]
au_revoir = ["au revoir", "à bientôt", "bye", "bonne journée", "à la prochaine"]
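
# Small-talk phrases are matched by substring, so e.g. "bonjour docteur" is
# caught; these turns are answered in French without calling the translator
# or the model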

def detect_smalltalk(user_input):
    lower_input = user_input.lower().strip()
    if any(phrase in lower_input for phrase in salutations):
        return "Bonjour ! Comment puis-je vous aider aujourd'hui ?", True
    if any(phrase in lower_input for phrase in remerciements):
        return "Avec plaisir ! Souhaitez-vous poser une autre question médicale ?", True
    if any(phrase in lower_input for phrase in au_revoir):
        return "Au revoir ! Prenez soin de votre santé et n'hésitez pas à revenir si besoin.", True
    return "", False

def medical_chatbot(user_input):
    global messages

    smalltalk_response, handled = detect_smalltalk(user_input)
    if handled:
        return smalltalk_response

    translated = translator(user_input, src_lang="fra_Latn", tgt_lang="eng_Latn")[0]['translation_text']
    
    messages.append({"role": "user", "content": translated})
    # Keep the system message plus the most recent max_history exchanges
    if len(messages) > max_history * 2:
        messages = [system_message] + messages[-max_history * 2:]

    # add_generation_prompt=True appends the assistant header so the model
    # produces an answer instead of continuing the user's turn
    prompt = text_gen_pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    response = text_gen_pipeline(
        prompt,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.4,
        top_k=150,
        top_p=0.75,
        eos_token_id=[
            text_gen_pipeline.tokenizer.eos_token_id,
            # Llama 3 chat models end each turn with <|eot_id|>
            text_gen_pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
    )

    # The pipeline returns prompt + completion, so slice off the echoed prompt
    output = response[0]['generated_text'][len(prompt):].strip()
    translated_back = translator(output, src_lang="eng_Latn", tgt_lang="fra_Latn")[0]['translation_text']
    
    messages.append({"role": "assistant", "content": translated_back})
    return translated_back

# === [LOGO LOAD] ===
logo = Image.open("AfriAI Solutions.jpg")

# === [GRADIO UI] ===
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo")) as demo:
    with gr.Row():
        gr.Image(value=logo, show_label=False, show_download_button=False, interactive=False, height=150)

    gr.Markdown("""
    # 🤖 Chatbot Médical AfriAI Solutions
    **Posez votre question médicale en français.**  
    Le chatbot vous répondra brièvement et avec bienveillance, puis vous demandera si vous souhaitez plus de détails.
    """, elem_id="title")

    chatbot = gr.Chatbot(label="Chat avec le Médecin Virtuel")
    msg = gr.Textbox(label="Votre question", placeholder="Exemple : Quels sont les symptômes du paludisme ?")
    clear = gr.Button("Effacer la conversation", variant="secondary")

    def respond(message, history):
        response = medical_chatbot(message)
        history = history or []
        history.append((message, response))
        return "", history

    def clear_conversation():
        # Reset the model-side history as well as the visible chat
        global messages
        messages = [system_message]
        return "", []

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(clear_conversation, None, [msg, chatbot])

demo.launch()
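
# To run locally (a sketch; assumes this file is saved as app.py):
#   export HF_TOKEN=<your Hugging Face token>
#   pip install torch transformers accelerate optimum-quanto gradio pillow
#   python app.py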