import threading

import torch
import gradio as gr
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Load the model ---
model_name_or_path = "facebook/MobileLLM-Pro"  # gated model: requires accepting the license and a Hugging Face token
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    token=True,  # `token` replaces the deprecated `use_auth_token` argument
    torch_dtype=torch.float16,
    device_map="auto",
)
model.eval()

# --- 2. Prediction function ---
def predict(text, max_length=128):
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# --- 3. Gradio interface ---
def launch_gradio():
    iface = gr.Interface(
        fn=predict,
        inputs=[gr.Textbox(lines=5, placeholder="Write your text here...")],
        outputs=[gr.Textbox(label="Model response")],
        title="MobileLLM-Pro Chat",
        description="Gradio interface for MobileLLM-Pro",
    )
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)

# --- 4. FastAPI API ---
app = FastAPI()

class RequestText(BaseModel):
    text: str
    max_length: int = 128

@app.post("/predict")
def api_predict(request: RequestText):
    return {"response": predict(request.text, request.max_length)}

# --- 5. Run Gradio in a thread so FastAPI can be started as well ---
if __name__ == "__main__":
    threading.Thread(target=launch_gradio, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=8000)
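
# --- Usage sketch (illustrative, not part of the script above) ---
# A minimal client-side check, assuming the server above is running locally on
# port 8000; the JSON fields mirror the RequestText model. Run it from a
# separate process or notebook, not from this file.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"text": "Hello, who are you?", "max_length": 64},
#   )
#   print(resp.json()["response"])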