#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Fine-tuning script for DeepSeek-R1-Distill-Qwen-14B-bnb-4bit using unsloth
RESEARCH TRAINING PHASE ONLY - No output generation
WORKS WITH PRE-TOKENIZED DATASET - No re-tokenization
"""

import os
import json
import logging
import argparse
import numpy as np
from dotenv import load_dotenv
import torch
from datasets import load_dataset
import transformers
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig
from transformers.data.data_collator import DataCollatorMixin
from peft import LoraConfig
from unsloth import FastLanguageModel

# Export the flash-attention opt-out flag before any model is loaded.
# NOTE(review): confirm the installed transformers/unsloth versions actually
# read TRANSFORMERS_NO_FLASH_ATTENTION.
os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"

# Tensorboard is optional: probe the import once and remember the result so
# the trainer only requests it as a reporting backend when it is installed.
try:
    import tensorboard  # noqa: F401  (imported only to test availability)
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("Tensorboard not available. Will skip tensorboard logging.")
else:
    TENSORBOARD_AVAILABLE = True

# Send INFO-and-above records both to the console and to training.log,
# using one shared line format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler("training.log")],
)
logger = logging.getLogger(__name__)

# Hub repository substituted when the caller passes the short alias
# "phi4-cognitive-dataset".
DEFAULT_DATASET = "George-API/phi4-cognitive-dataset"

def load_config(config_path):
    """Load the training configuration from a JSON file.

    Args:
        config_path: Path to a UTF-8 encoded JSON file whose top-level value
            is an object.

    Returns:
        The parsed configuration as a dict.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level JSON value is not an object.
    """
    logger.info(f"Loading config from {config_path}")
    # JSON is UTF-8 by specification; don't depend on the platform locale
    # (e.g. cp1252 on Windows) to decode it.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # Callers use config.get(...); fail clearly here instead of with an
    # AttributeError somewhere downstream.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a JSON object at the top level, "
            f"got {type(config).__name__}"
        )
    return config

def load_and_prepare_dataset(dataset_name, config):
    """
    Load and prepare the dataset for fine-tuning.

    The dataset is sorted by ``dataset_config.sort_by_field`` (default
    ``"prompt_number"``) in ``dataset_config.sort_direction`` order (default
    ``"ascending"``; any other value sorts descending). If
    ``dataset_config.shuffle_seed`` is present, the sorted dataset is then
    shuffled with that seed.
    NO TOKENIZATION - DATASET IS ALREADY TOKENIZED.

    Args:
        dataset_name: Hub id or local path; the alias
            "phi4-cognitive-dataset" is replaced by DEFAULT_DATASET.
        config: Full training config dict; only its "dataset_config"
            section is read.

    Returns:
        The prepared ``train`` split (or the dataset itself when
        ``load_dataset`` returned a single split).

    Raises:
        ValueError: if there is no ``train`` split or the sort field is not
            a column of the dataset.
        Exception: whatever ``load_dataset`` raises, re-raised after printing
            access hints.
    """
    # Use the default dataset path if no specific path is provided
    if dataset_name == "phi4-cognitive-dataset":
        dataset_name = DEFAULT_DATASET

    logger.info(f"Loading dataset: {dataset_name}")

    # Only the Hub/disk access gets the "does it exist / is it accessible"
    # hints; preparation errors below carry their own messages.
    try:
        dataset = load_dataset(dataset_name)
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        print(f"Failed to load dataset: {dataset_name}")
        print("Make sure the dataset exists and is accessible.")
        print("If it's a private dataset, ensure your HF_TOKEN has access to it.")
        raise

    # load_dataset normally returns a DatasetDict (a dict subclass) keyed by
    # split name; pick the train split explicitly.
    if isinstance(dataset, dict):
        if 'train' not in dataset:
            raise ValueError(
                f"Dataset {dataset_name} has no 'train' split; "
                f"available splits: {list(dataset.keys())}"
            )
        dataset = dataset['train']

    dataset_config = config.get("dataset_config", {})
    sort_field = dataset_config.get("sort_by_field", "prompt_number")
    sort_direction = dataset_config.get("sort_direction", "ascending")

    if sort_field not in dataset.column_names:
        raise ValueError(
            f"Sort field '{sort_field}' not found in dataset columns: {dataset.column_names}"
        )

    # Compare case-insensitively so "Ascending" is not treated as descending.
    descending = str(sort_direction).lower() != "ascending"
    logger.info(f"Sorting dataset by {sort_field} in {sort_direction} order")
    dataset = dataset.sort(sort_field, reverse=descending)

    # Optional shuffle with a fixed seed; note this replaces the sort order.
    if "shuffle_seed" in dataset_config:
        shuffle_seed = dataset_config["shuffle_seed"]
        logger.warning(
            f"Shuffling dataset with seed {shuffle_seed}; "
            f"this discards the ordering by {sort_field}"
        )
        dataset = dataset.shuffle(seed=shuffle_seed)

    # Print dataset structure for debugging
    logger.info(f"Dataset loaded with {len(dataset)} entries")
    logger.info(f"Dataset columns: {dataset.column_names}")

    if len(dataset) > 0:
        sample = dataset[0]
        logger.info(f"Sample entry structure: {list(sample.keys())}")
        if 'conversations' in sample:
            logger.info(f"Sample conversations structure: {sample['conversations'][:1]}")

    return dataset

def tokenize_string(text, tokenizer):
    """Encode ``text`` into token ids with ``tokenizer``, adding no special tokens.

    Falsy input (``None`` or ``""``) yields an empty list without the
    tokenizer being called.
    """
    return tokenizer.encode(text, add_special_tokens=False) if text else []

# Data collator for pre-tokenized dataset
class PreTokenizedCollator(DataCollatorMixin):
    """
    Data collator for pre-tokenized datasets.

    Each feature is expected to provide ``input_ids`` (and optionally
    ``labels``). If ``input_ids`` is absent, it is recovered from a
    ``conversations`` field; string values are encoded with ``tokenizer``
    when one was supplied.

    The returned batch holds right-padded ``input_ids``, ``attention_mask``
    and ``labels`` tensors of dtype ``torch.long``. Padding positions hold
    ``pad_token_id`` in ``input_ids``, 0 in ``attention_mask`` and -100 in
    ``labels``. A feature without ``labels`` uses its ``input_ids`` as labels.
    """

    def __init__(self, pad_token_id=0, tokenizer=None):
        # A tokenizer without a pad token reports pad_token_id=None, which
        # cannot fill a tensor. Padded positions are masked by attention_mask
        # and labels=-100 anyway, so 0 is a safe filler.
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id
        # Kept only to encode string content into token ids.
        self.tokenizer = tokenizer

    def _encode(self, text):
        """Encode ``text`` with the tokenizer, without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _ids_from_conversations(self, conversations):
        """Recover token ids from a ``conversations`` value, or return None.

        Recognized layouts, tried in order:
          1. list of dicts whose first item has ``content``: a str is encoded
             (tokenizer required), a list of ints is used as-is;
          2. list of dicts whose first item has ``input_ids``;
          3. a flat list of ints;
          4. a list of strings, joined with single spaces and encoded
             (tokenizer required).
        """
        if not isinstance(conversations, list) or not conversations:
            return None

        first = conversations[0]
        logger.debug("Conversations type: %s; first item type: %s", type(conversations), type(first))

        if isinstance(first, dict) and 'content' in first:
            # NOTE(review): only the FIRST item's content is used; any later
            # turns are ignored - confirm this matches the dataset layout.
            content = first['content']
            if isinstance(content, str) and self.tokenizer:
                logger.debug("Tokenizing string content: %s...", content[:50])
                return self._encode(content)
            if isinstance(content, list) and all(isinstance(x, int) for x in content):
                return content
            logger.warning(f"Unexpected content format: {type(content)}")
            return None

        if isinstance(first, dict) and 'input_ids' in first:
            return first['input_ids']

        if all(isinstance(x, int) for x in conversations):
            return conversations

        if all(isinstance(x, str) for x in conversations) and self.tokenizer:
            return self._encode(" ".join(conversations))

        return None

    def _normalize_input_ids(self, input_ids):
        """Coerce an ``input_ids`` value into a list of token ids where possible."""
        if isinstance(input_ids, str):
            if not self.tokenizer:
                raise ValueError("input_ids is a string but no tokenizer was provided to encode it")
            logger.debug("Converting string input_ids to tokens: %s...", input_ids[:50])
            return self._encode(input_ids)
        if isinstance(input_ids, torch.Tensor):
            return input_ids.tolist()
        if not isinstance(input_ids, list):
            try:
                return list(input_ids)
            except TypeError:
                logger.error(f"Could not convert input_ids to list: {type(input_ids)}")
        return input_ids

    def __call__(self, features):
        """Collate ``features`` (a list of dicts) into a padded batch dict.

        Raises:
            ValueError: if ``features`` is empty or any feature has no
                recoverable ``input_ids``.
        """
        if not features:
            raise ValueError("Cannot collate an empty list of features")

        # Per-batch diagnostics at DEBUG so they don't flood training logs.
        logger.debug("Sample feature keys: %s", list(features[0].keys()))

        processed_features = []
        for feature in features:
            # Shallow copy: don't mutate the dicts handed to us by the caller.
            feature = dict(feature)
            if 'input_ids' not in feature and 'conversations' in feature:
                recovered = self._ids_from_conversations(feature['conversations'])
                if recovered is not None:
                    feature['input_ids'] = recovered
            if 'input_ids' in feature:
                feature['input_ids'] = self._normalize_input_ids(feature['input_ids'])
            processed_features.append(feature)

        # Check every feature, not just the first, so the error is explicit
        # instead of a KeyError further down.
        missing = next((f for f in processed_features if 'input_ids' not in f), None)
        if missing is not None:
            logger.error(f"Could not find input_ids in features. Available keys: {list(missing.keys())}")
            if 'conversations' in missing:
                logger.error(f"Conversations structure: {missing['conversations'][:1]}")
            raise ValueError("Could not find input_ids in dataset. Please check dataset structure.")

        input_rows = [torch.as_tensor(f["input_ids"], dtype=torch.long) for f in processed_features]
        label_rows = [
            torch.as_tensor(f["labels"], dtype=torch.long) if "labels" in f else ids
            for f, ids in zip(processed_features, input_rows)
        ]

        # Size the batch to the longest row, counting labels as well so a
        # labels row longer than every input_ids row cannot overflow.
        batch_max_len = max(len(row) for row in input_rows + label_rows)
        batch_size = len(processed_features)

        batch = {
            "input_ids": torch.full((batch_size, batch_max_len), self.pad_token_id, dtype=torch.long),
            "attention_mask": torch.zeros((batch_size, batch_max_len), dtype=torch.long),
            "labels": torch.full((batch_size, batch_max_len), -100, dtype=torch.long),  # -100 is ignored in loss
        }

        for i, (ids, labels) in enumerate(zip(input_rows, label_rows)):
            batch["input_ids"][i, :len(ids)] = ids
            batch["attention_mask"][i, :len(ids)] = 1
            batch["labels"][i, :len(labels)] = labels

        return batch

def create_training_marker(output_dir):
    """Write the marker files that flag an in-progress research training run.

    Two files are written, in this order:
      * ``TRAINING_ACTIVE`` in the current working directory (for app.py to
        find), whose text names ``output_dir``;
      * ``RESEARCH_TRAINING_ONLY`` inside ``output_dir``, which is created
        first if it does not exist.
    """
    active_note = f"Training active in {output_dir}"
    with open("TRAINING_ACTIVE", "w") as marker_file:
        marker_file.write(active_note)

    os.makedirs(output_dir, exist_ok=True)
    research_note_path = os.path.join(output_dir, "RESEARCH_TRAINING_ONLY")
    with open(research_note_path, "w") as note_file:
        note_file.write("This model is for research training only. No interactive outputs.")

def remove_training_marker():
    """Delete the ``TRAINING_ACTIVE`` marker from the working directory.

    Does nothing when the marker does not exist, so it can be called more
    than once; logs only when a file was actually removed.
    """
    marker_path = "TRAINING_ACTIVE"
    if not os.path.exists(marker_path):
        return
    os.remove(marker_path)
    logger.info("Removed training active marker")

def load_model_safely(model_name, max_seq_length, dtype=None):
    """
    Load a model and tokenizer, trying progressively more conservative strategies.

    Strategies, in order:
      1. unsloth ``FastLanguageModel.from_pretrained`` with ``load_in_4bit=True``
         and ``use_flash_attention=False``.
      2. If (1) raises a TypeError mentioning an "unexpected keyword argument",
         retry unsloth with only ``model_name``/``max_seq_length``/``dtype``.
         Both optional kwargs are dropped because either one may be the
         argument the installed version rejects.
      3. If any unsloth attempt raises, fall back to Hugging Face
         ``AutoConfig``/``AutoTokenizer``/``AutoModelForCausalLM`` with
         ``load_in_4bit=True`` and ``device_map="auto"``.

    Args:
        model_name: Hub id or local path of the model.
        max_seq_length: Maximum sequence length passed to unsloth.
        dtype: Torch dtype for the weights; the HF fallback uses
            ``torch.float16`` when this is None.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    try:
        logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
        try:
            logger.info("Loading model with flash attention DISABLED")
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=dtype,
                load_in_4bit=True,  # intended for the already-quantized checkpoint
                use_flash_attention=False,  # explicitly request flash attention off
            )
            logger.info("Model loaded successfully with unsloth with 4-bit quantization and flash attention disabled")
            return model, tokenizer

        except TypeError as e:
            # Any other TypeError is a real failure: hand it to the HF fallback.
            if "unexpected keyword argument" not in str(e):
                raise
            logger.warning(f"Unsloth loading error with 4-bit: {e}")
            logger.info("Trying alternative loading method for Qwen model...")

            # Pass only the core arguments. Repeating the rejected kwarg
            # (load_in_4bit or use_flash_attention) would fail the same way.
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=dtype,
            )
            logger.info("Model loaded successfully with unsloth using alternative method")
            return model, tokenizer

    except Exception as e:
        # Fallback to standard loading if unsloth methods fail
        logger.warning(f"Unsloth loading failed: {e}")
        logger.info("Falling back to standard Hugging Face loading...")

        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        if hasattr(config, "use_flash_attention"):
            config.use_flash_attention = False
            logger.info("Disabled flash attention in model config")

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            device_map="auto",
            torch_dtype=dtype or torch.float16,
            load_in_4bit=True,
            # Keep consistent with the config/tokenizer loads above: a
            # checkpoint whose config needs remote code needs it for the
            # model class as well.
            trust_remote_code=True,
        )
        logger.info("Model loaded successfully with standard HF loading and flash attention disabled")
        return model, tokenizer

def _select_report_backends():
    """Return the ``report_to`` list for TrainingArguments.

    Includes "tensorboard" when TENSORBOARD_AVAILABLE is true and "wandb"
    when the WANDB_API_KEY environment variable is set. Returns ["none"] when
    neither applies.
    """
    reports = []
    if TENSORBOARD_AVAILABLE:
        reports.append("tensorboard")
        logger.info("Tensorboard available and enabled for reporting")
    else:
        logger.warning("Tensorboard not available - metrics won't be logged to tensorboard")

    if os.getenv("WANDB_API_KEY"):
        reports.append("wandb")
        logger.info("Wandb API key found, enabling wandb reporting")

    if not reports:
        logger.warning("No reporting backends available - training metrics won't be logged")
        return ["none"]
    return reports


def train(config_path, dataset_name, output_dir):
    """Main training function - RESEARCH TRAINING PHASE ONLY.

    Loads the JSON config and the pre-tokenized dataset, loads the 4-bit
    model, applies LoRA via PEFT, trains with a Hugging Face Trainer, and
    writes the model, LoRA adapter, tokenizer and a copy of the config.

    Args:
        config_path: Path to the JSON config (sections model_config,
            training_config, hardware_config, lora_config, dataset_config).
        dataset_name: Hub id or path; the alias "phi4-cognitive-dataset"
            maps to DEFAULT_DATASET.
        output_dir: Destination directory. When falsy,
            training_config["output_dir"] (default "fine_tuned_model") is used.

    Returns:
        The output directory that artifacts were written to.

    Raises:
        ValueError: if model_config.model_name_or_path is missing.
    """
    # Load environment variables from a .env file, if present
    load_dotenv()
    config = load_config(config_path)

    # Extract configs
    model_config = config.get("model_config", {})
    training_config = config.get("training_config", {})
    hardware_config = config.get("hardware_config", {})
    lora_config = config.get("lora_config", {})
    dataset_config = config.get("dataset_config", {})

    # Flash attention is forced off regardless of the config value
    hardware_config["use_flash_attention"] = False
    logger.info("Flash attention has been DISABLED due to GPU compatibility issues")

    # This script only supports the research training phase; make the logged
    # "Setting ..." statement true (it is also reflected in the saved config).
    if not dataset_config.get("training_phase_only", True):
        logger.warning("This script is meant for research training phase only")
        logger.warning("Setting training_phase_only=True")
        dataset_config["training_phase_only"] = True

    logger.info("IMPORTANT: Using pre-tokenized dataset - No tokenization will be performed")

    output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
    os.makedirs(output_dir, exist_ok=True)

    create_training_marker(output_dir)

    try:
        logger.info("RESEARCH TRAINING PHASE ACTIVE - No output generation")
        logger.info("Configuration Summary:")
        model_name = model_config.get("model_name_or_path")
        if not model_name:
            raise ValueError("model_config.model_name_or_path must be set in the config file")
        logger.info(f"Model: {model_name}")
        logger.info(f"Dataset: {dataset_name if dataset_name != 'phi4-cognitive-dataset' else DEFAULT_DATASET}")
        logger.info(f"Output directory: {output_dir}")
        logger.info("IMPORTANT: Using already 4-bit quantized model - not re-quantizing")

        dataset = load_and_prepare_dataset(dataset_name, config)

        max_seq_length = training_config.get("max_seq_length", 2048)

        logger.info("Creating LoRA configuration")
        lora_config_obj = LoraConfig(
            r=lora_config.get("r", 16),
            lora_alpha=lora_config.get("lora_alpha", 32),
            lora_dropout=lora_config.get("lora_dropout", 0.05),
            bias=lora_config.get("bias", "none"),
            target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
        )

        logger.info("Loading pre-quantized model safely (preserving 4-bit quantization)")
        dtype = torch.float16 if hardware_config.get("fp16", True) else None
        model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)

        # The tokenizer comes from load_model_safely (no separate load), so the
        # pad token must be fixed up here. The collator needs a usable
        # pad_token_id to build padded batches.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info("Tokenizer had no pad token; using the EOS token for padding")

        logger.info("Applying LoRA to model using standard PEFT method")
        from peft import get_peft_model
        model = get_peft_model(model, lora_config_obj)
        logger.info("Successfully applied LoRA with standard PEFT")

        # No need to format the dataset - it's already pre-tokenized
        logger.info("Using pre-tokenized dataset - skipping tokenization step")
        training_dataset = dataset

        reports = _select_report_backends()

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=training_config.get("num_train_epochs", 3),
            per_device_train_batch_size=training_config.get("per_device_train_batch_size", 2),
            gradient_accumulation_steps=training_config.get("gradient_accumulation_steps", 4),
            learning_rate=training_config.get("learning_rate", 2e-5),
            lr_scheduler_type=training_config.get("lr_scheduler_type", "cosine"),
            warmup_ratio=training_config.get("warmup_ratio", 0.03),
            weight_decay=training_config.get("weight_decay", 0.01),
            optim=training_config.get("optim", "adamw_torch"),
            logging_steps=training_config.get("logging_steps", 10),
            save_steps=training_config.get("save_steps", 200),
            save_total_limit=training_config.get("save_total_limit", 3),
            fp16=hardware_config.get("fp16", True),
            bf16=hardware_config.get("bf16", False),
            max_grad_norm=training_config.get("max_grad_norm", 0.3),
            report_to=reports,
            logging_first_step=training_config.get("logging_first_step", True),
            disable_tqdm=training_config.get("disable_tqdm", False),
            # Keep all columns: the collator may need e.g. "conversations",
            # which the model's forward() does not accept.
            remove_unused_columns=False
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=training_dataset,
            data_collator=PreTokenizedCollator(pad_token_id=tokenizer.pad_token_id, tokenizer=tokenizer),
        )

        logger.info("Starting training - RESEARCH PHASE ONLY")
        trainer.train()

        logger.info(f"Saving model to {output_dir}")
        trainer.save_model(output_dir)

        # Save LoRA adapter separately for easier deployment
        lora_output_dir = os.path.join(output_dir, "lora_adapter")
        model.save_pretrained(lora_output_dir)
        logger.info(f"Saved LoRA adapter to {lora_output_dir}")

        tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
        tokenizer.save_pretrained(tokenizer_output_dir)
        logger.info(f"Saved tokenizer to {tokenizer_output_dir}")

        # Copy config file for reference
        with open(os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info("Training complete - RESEARCH PHASE ONLY")
        return output_dir

    finally:
        # Always remove the training marker when done
        remove_training_marker()

if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run the research-only
    # training job; on failure, log, clean up the marker and re-raise.
    cli = argparse.ArgumentParser(
        description="Fine-tune Unsloth/DeepSeek-R1-Distill-Qwen-14B-4bit model (RESEARCH ONLY)"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="transformers_config.json",
        help="Path to the transformers config JSON file",
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default="phi4-cognitive-dataset",
        help="Dataset name or path",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory for the fine-tuned model",
    )
    cli_args = cli.parse_args()

    try:
        saved_to = train(cli_args.config, cli_args.dataset, cli_args.output_dir)
    except Exception as exc:
        logger.error(f"Training failed: {str(exc)}")
        remove_training_marker()  # harmless if the marker is already gone
        raise
    else:
        print(f"Research training completed. Model saved to: {saved_to}")