#!/usr/bin/env python
# coding=utf-8

"""
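Installation script for the Hugging Face Space.
Installs dependencies in a fixed order (pip, PyTorch, core ML libraries,
Unsloth, then requirements.txt), verifies that critical packages import,
and writes INSTALL_SUCCESS.txt on success. Intended to run during the
Space build process.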
"""

import importlib
import logging
import os
import shlex
import subprocess
import sys
import traceback
from pathlib import Path

# Route INFO-and-above records to stdout with a timestamp, level and message.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
)
logger = logging.getLogger(name=__name__)

def run_command(cmd, description=""):
    """Run a command, log its output, and report whether it succeeded.

    Args:
        cmd: Either a command string, executed through the shell
            (``shell=True``) as before, or a sequence of arguments, executed
            directly (``shell=False``) so no shell parsing, quoting or
            redirection is applied to them.
        description: Optional human-readable label for the log. When empty,
            the command itself is logged.

    Returns:
        True if the command exited with status 0; False if it exited with a
        non-zero status or could not be started at all.
    """
    use_shell = isinstance(cmd, str)
    shown = cmd if use_shell else " ".join(str(arg) for arg in cmd)
    logger.info(f"Running: {description or shown}")
    try:
        process = subprocess.run(
            cmd,
            shell=use_shell,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        logger.error(f"Command failed with exit code {e.returncode}")
        logger.error(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # Raised when the process cannot be spawned at all (e.g. the first
        # element of an argument list is not an existing executable).
        logger.error(f"Command could not be started: {e}")
        return False

    logger.info(f"Command output: {process.stdout}")
    if process.stderr:
        # A successful command may still print diagnostics (e.g. pip
        # warnings) on stderr; surface them instead of discarding them.
        logger.warning(f"Command stderr: {process.stderr}")
    return True

def _pip_install_command(*args):
    """Return a shell-quoted ``<python> -m pip install <args...>`` command string."""
    parts = [sys.executable, "-m", "pip", "install", *args]
    # Every argument is quoted: version specifiers such as "torch>=2.0.0,<2.2.0"
    # contain '>' and '<', which an unquoted shell treats as I/O redirections,
    # and paths (interpreter, requirements file) may contain spaces.
    return " ".join(shlex.quote(str(part)) for part in parts)


def install_dependencies():
    """Install all required dependencies in the correct order.

    Steps: upgrade pip (best-effort), install PyTorch from the CUDA 11.8
    extra index, install the core ML libraries, install Unsloth, install
    everything in ``requirements.txt`` next to this script, then verify that
    the critical packages import.

    Returns:
        True if every required step succeeded and the imports verified;
        False if requirements.txt is missing, a required install step
        failed, verification failed, or an unexpected error occurred.
    """
    current_dir = Path(__file__).parent
    req_path = current_dir / "requirements.txt"

    if not req_path.exists():
        logger.error(f"Requirements file not found: {req_path}")
        return False

    # Each entry: (pip install arguments, log description, abort on failure?)
    steps = [
        # Step 1: Upgrade pip
        (("--upgrade", "pip"), "Upgrading pip", False),
        # Step 2: Install direct torch version for CUDA compatibility
        (
            (
                "torch>=2.0.0,<2.2.0",
                "--extra-index-url",
                "https://download.pytorch.org/whl/cu118",
            ),
            "Installing PyTorch with CUDA support",
            True,
        ),
        # Step 3: Install base dependencies
        (
            ("transformers", "accelerate", "bitsandbytes", "peft", "einops"),
            "Installing ML dependencies",
            True,
        ),
        # Step 4: Install unsloth separately
        (("unsloth>=2024.3",), "Installing Unsloth", True),
        # Step 5: Install all remaining requirements
        (("-r", req_path), "Installing all requirements", True),
    ]

    try:
        for pip_args, description, required in steps:
            if run_command(_pip_install_command(*pip_args), description):
                continue
            if required:
                logger.error(f"Required installation step failed: {description}")
                return False
            # Upgrading pip is best-effort: the existing pip can still attempt
            # the remaining installs.
            logger.warning(f"Optional installation step failed, continuing: {description}")

        # Verify critical packages
        if not verify_imports():
            logger.error("Failed to verify critical packages")
            return False

        logger.info("All dependencies installed successfully!")
        return True

    except Exception as e:
        # logger.exception also records the traceback on the same log stream.
        logger.exception(f"Error installing dependencies: {str(e)}")
        return False

def verify_imports(packages=None):
    """Verify that critical packages can be imported, and report CUDA status.

    Args:
        packages: Optional iterable of module names to import. Defaults to
            the packages this Space depends on (torch, transformers, unsloth,
            peft, gradio, accelerate, bitsandbytes).

    Returns:
        True if every package imported without raising, otherwise False.
        The CUDA check afterwards only logs and never changes the result.
    """
    if packages is None:
        packages = (
            "torch", "transformers", "unsloth", "peft",
            "gradio", "accelerate", "bitsandbytes",
        )

    # These packages were installed by pip after this interpreter started;
    # invalidate the import system's finder caches so new modules are seen.
    importlib.invalidate_caches()

    success = True
    for package in packages:
        try:
            module = importlib.import_module(package)
        except ImportError:
            logger.error(f"CRITICAL: Failed to import {package}")
            success = False
        except Exception as e:
            # A package's import-time code may fail with errors other than
            # ImportError; count those as verification failures too.
            logger.error(f"Error verifying {package}: {str(e)}")
            success = False
        else:
            version = getattr(module, "__version__", "unknown")
            logger.info(f"Successfully imported {package} (version: {version})")

    # Check CUDA (informational only)
    try:
        import torch

        if torch.cuda.is_available():
            device_count = torch.cuda.device_count()
            device_name = torch.cuda.get_device_name(0)
            logger.info(f"CUDA available - Devices: {device_count}, Name: {device_name}")
        else:
            logger.warning("CUDA not available - This might affect performance")
    except Exception as e:
        logger.error(f"Error checking CUDA: {str(e)}")

    return success

def main():
    """Install dependencies and write the INSTALL_SUCCESS.txt marker file.

    Returns:
        Process exit status: 0 on success, 1 if installation failed or an
        unexpected error occurred.
    """
    logger.info("Starting installation for Phi-4 Unsloth Training Space")

    try:
        # Install dependencies
        if not install_dependencies():
            logger.error("Failed to install dependencies")
            # Return rather than sys.exit() so every failure path uses the
            # same exit-status contract; the __main__ guard exits with it.
            return 1

        # Create marker file (in the current working directory) to show
        # successful installation
        with open("INSTALL_SUCCESS.txt", "w", encoding="utf-8") as f:
            f.write("Installation completed successfully")

        logger.info("Installation completed successfully")
        return 0

    except Exception as e:
        # logger.exception also records the traceback on the same log stream.
        logger.exception(f"Installation failed with error: {str(e)}")
        return 1

if __name__ == "__main__":
    # Use main()'s return value (0 = success, 1 = failure) as the process exit code.
    sys.exit(main())