import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import os
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import BitsAndBytesConfig
import logging

# Logging setup: configure the root logger at INFO level, then obtain a
# logger named after this module for all messages emitted below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Populate os.environ from a local .env file, if one exists.
load_dotenv()

# Log in to the Hugging Face Hub with the token from HF_TOKEN.
# Calling login() with a None token makes huggingface_hub fall back to an
# interactive prompt, which blocks a headless server. When no token is set,
# skip the explicit login and rely on any cached credentials instead.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    logger.warning(
        "HF_TOKEN is not set; skipping Hugging Face login and relying on "
        "cached credentials (if any)"
    )

# Model configuration
model_path = "mistralai/Mistral-Large-Instruct-2411"


def select_compute_dtype():
    """Return the floating-point compute dtype to use on the current GPU.

    bfloat16 is chosen for CUDA compute capability 8.x (Ampere) and newer,
    including 9.x (Hopper). Older GPUs get float16. If CUDA is unavailable,
    float16 is returned; the previous code raised from
    ``torch.cuda.get_device_capability()`` in that case.
    """
    if not torch.cuda.is_available():
        return torch.float16
    major, _minor = torch.cuda.get_device_capability()
    # ">= 8" rather than "== 8": the old check sent capability 9.x
    # (and later) GPUs to float16 even though they support bfloat16.
    return torch.bfloat16 if major >= 8 else torch.float16


# Automatically pick the optimal compute dtype
dtype = select_compute_dtype()
logger.info(f"Using dtype: {dtype}")

# 4-bit quantization settings: NF4 weight format, double quantization
# enabled, and the dtype selected above as the computation dtype.
logger.info("Configuring 4-bit quantization")
_bnb_options = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=dtype,
)
quantization_config = BitsAndBytesConfig(**_bnb_options)

# Model initialization: tokenizer first, then the 4-bit quantized model,
# then the text-generation pipeline that wraps both.
logger.info(f"Loading tokenizer from {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
logger.info("Tokenizer loaded successfully")

logger.info(f"Loading model from {model_path} with 4-bit quantization")
# NOTE(review): no torch_dtype is passed, so the dtype of non-quantized
# weights is left to transformers/bitsandbytes defaults. Consider passing
# torch_dtype=dtype to match the compute dtype; verify this against the
# installed transformers version.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map="auto",  # let accelerate place the layers on available devices
)
logger.info("Model loaded successfully")

logger.info("Creating inference pipeline")
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
logger.info("Inference pipeline created successfully")

def generate_response(message, temperature=0.7, max_new_tokens=256):
    """Run the text-generation pipeline on ``message`` and return its text.

    Args:
        message: Prompt text entered in the Gradio textbox.
        temperature: Sampling temperature from the UI slider. Sampling is
            enabled explicitly (``do_sample=True``). Without it, transformers
            decodes greedily and ignores ``temperature``, so the slider had
            no effect.
        max_new_tokens: Upper bound on generated tokens. It is coerced to
            ``int`` because slider values may be delivered as floats.

    Returns:
        ``generated_text`` of the first returned sequence. If generation
        fails, a French error message containing the exception text is
        returned instead; the full traceback is logged.
    """
    # NOTE(review): the raw message is sent to the pipeline without the
    # model's chat template, and the pipeline's default output includes the
    # prompt. Confirm both are intended for an instruct model.
    try:
        logger.info(f"Generating response for message: {message[:50]}...")
        parameters = {
            "temperature": float(temperature),
            "max_new_tokens": int(max_new_tokens),
            # Required for temperature to influence decoding at all.
            "do_sample": True,
        }
        logger.info(f"Parameters: {parameters}")

        response = pipe(message, **parameters)
        logger.info("Response generated successfully")
        return response[0]['generated_text']
    except Exception as e:
        # Top-level UI boundary: log the traceback, show the error to the user.
        logger.exception(f"Error during generation: {str(e)}")
        return f"Une erreur s'est produite : {str(e)}"

def _build_demo():
    """Assemble the Gradio UI wired to ``generate_response``.

    The three inputs map positionally onto
    ``generate_response(message, temperature, max_new_tokens)``. The single
    output textbox displays the returned string.
    """
    message_input = gr.Textbox(
        label="Votre message",
        placeholder="Entrez votre message ici...",
    )
    temperature_input = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.7, label="Température"
    )
    token_budget_input = gr.Slider(
        minimum=10, maximum=3000, value=256, step=10, label="Nombre de tokens"
    )
    response_output = gr.Textbox(label="Réponse")
    # NOTE(review): the title/description mention "Sacha-Mistral", but the
    # loaded checkpoint is model_path; confirm the naming is intended.
    return gr.Interface(
        fn=generate_response,
        inputs=[message_input, temperature_input, token_budget_input],
        outputs=response_output,
        title="Chat avec Sacha-Mistral",
        description="Un assistant conversationnel en français basé sur le modèle Sacha-Mistral",
    )


# Gradio interface
demo = _build_demo()

if __name__ == "__main__":
    # share=True asks Gradio to create a publicly reachable share link,
    # exposing the model to anyone holding the URL. The default stays True
    # for backward compatibility, but it can now be disabled by setting
    # GRADIO_SHARE to 0/false/no/off.
    share = os.getenv("GRADIO_SHARE", "1").strip().lower() not in {
        "0", "false", "no", "off",
    }
    logger.info("Starting Gradio interface")
    demo.launch(share=share)