#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Fine-tuning script for DeepSeek-R1-Distill-Qwen-14B-bnb-4bit using unsloth
RESEARCH TRAINING PHASE ONLY - No output generation
WORKS WITH PRE-TOKENIZED DATASET - No re-tokenization
"""

# Set critical environment variables before any imports
import os
# Environment switches applied before torch is imported below:
# expandable allocator segments reduce CUDA memory fragmentation, and the two
# flags are intended to keep xformers / flash-attention kernels out of the run
# (whether each downstream library reads them is not checked here).
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "XFORMERS_DISABLED": "1",
        "TRANSFORMERS_NO_FLASH_ATTENTION": "1",
    }
)

import json
import logging
import argparse
import numpy as np
from dotenv import load_dotenv
import torch
import sys
from datasets import load_dataset
import transformers
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig
from transformers.data.data_collator import DataCollatorMixin
from peft import LoraConfig
from unsloth import FastLanguageModel

# Single-process defaults for the torch.distributed / DeepSpeed rendezvous
# variables. With these present, DeepSpeed does not need MPI to discover the
# process layout. setdefault() keeps values already exported by a launcher
# (torchrun / deepspeed), which assigns each worker its own RANK and
# LOCAL_RANK. Overwriting them would make every worker believe it is rank 0
# of a 1-process job.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "9994")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Try to import deepspeed; if only mpi4py is missing, install it and retry.
# The error paths below log through `logger`, which is otherwise first bound
# after logging.basicConfig further down. Bind it here so they do not raise
# NameError. logging.getLogger() returns the same object for the same name,
# so the later assignment rebinds the identical logger. Until basicConfig
# runs, only WARNING and above are emitted (via logging's last-resort
# stderr handler).
logger = logging.getLogger(__name__)
try:
    import deepspeed
except ImportError as e:
    if "mpi4py" not in str(e):
        logger.error(f"Failed to import deepspeed: {e}")
        raise
    logger.warning("mpi4py not found, installing...")
    import importlib
    import subprocess
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "mpi4py"])
        # Packages installed after interpreter start may be hidden by cached
        # path-finder state; refresh it before retrying the import.
        importlib.invalidate_caches()
        import deepspeed
        logger.info("Successfully installed mpi4py and imported deepspeed")
    except Exception as install_error:
        logger.warning(f"Failed to install mpi4py: {install_error}")
        logger.warning("Continuing without DeepSpeed MPI support")
        # Flag checked later to skip DeepSpeed entirely.
        os.environ["DISABLE_DEEPSPEED_MPI"] = "1"

# Re-assert the attention-related switches and make CUDA kernel launches
# synchronous. NOTE(review): CUDA_LAUNCH_BLOCKING=1 serializes kernel
# launches. That makes CUDA errors point at the failing call, but it slows
# training down.
os.environ.update(
    {
        "TRANSFORMERS_NO_FLASH_ATTENTION": "1",
        "CUDA_LAUNCH_BLOCKING": "1",
        "XFORMERS_DISABLED": "1",
    }
)

# Drop any cached xformers modules from the import cache. This only removes
# the sys.modules entries; objects that already reference xformers keep them.
sys.modules.pop("xformers", None)
sys.modules.pop("xformers.ops", None)

# Patch Python's import system so that xformers cannot be imported.
class XFormersBlocker:
    """Meta-path finder that makes the ``xformers`` package unimportable.

    For ``xformers`` itself or any ``xformers.*`` submodule, ``find_spec``
    raises ``ModuleNotFoundError``, so ``import xformers`` fails at once and
    ``try: import xformers / except ImportError`` guards take their fallback.
    Simply returning ``None`` would not block anything. ``None`` only tells
    the import system to ask the next finder on ``sys.meta_path``, which then
    locates the installed package.

    For every other name it returns ``None``, so the normal finders resolve
    it as usual. Matching is by package prefix rather than substring, so
    unrelated modules whose names merely contain "xformers" still import.
    """

    _BLOCKED_PACKAGE = "xformers"

    def __init__(self, original_importer=None):
        # Accepted for backward compatibility. There is no need to delegate to
        # it: the remaining sys.meta_path finders are consulted anyway
        # whenever this finder returns None.
        self.original_importer = original_importer

    def find_spec(self, fullname, path, target=None):
        package = self._BLOCKED_PACKAGE
        if fullname == package or fullname.startswith(package + "."):
            raise ModuleNotFoundError(
                f"import of {fullname!r} is blocked by XFormersBlocker",
                name=fullname,
            )
        return None


# Install the blocker ahead of every other finder.
sys.meta_path.insert(0, XFormersBlocker())

# Route log records both to the console and to training.log in the current
# working directory, at INFO level and above.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [logging.StreamHandler(), logging.FileHandler("training.log")]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)
logger = logging.getLogger(__name__)

# Report the PyTorch / CUDA environment. torch is imported unconditionally at
# the top of this file, so execution can only reach this point with torch
# installed, and no install fallback is needed here.
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"CUDA version: {torch.version.cuda}")
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

# Best-effort installation of flash-attention. Failures are logged and the
# script continues; the availability probe below records what is importable.
# NOTE(review): load_model_safely forces eager attention, so flash-attn is not
# used by the current training path even when this succeeds.
import importlib.util
import subprocess

if importlib.util.find_spec("flash_attn") is not None:
    logger.info("flash-attention already installed; skipping installation")
else:
    logger.info("Attempting to install flash-attention...")
    # flash-attn's build imports torch. pip's isolated build environment does
    # not contain torch, so build isolation is disabled. torch is already
    # installed, since this file imported it at the top.
    pip_cmd = [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"]
    try:
        try:
            logger.info("Trying standard pip install for flash-attn")
            subprocess.check_call(pip_cmd)
        except subprocess.CalledProcessError as pip_error:
            logger.warning(f"Standard installation failed: {pip_error}")
            if "PIP_EXTRA_INDEX_URL" in os.environ:
                # The retry would run exactly the same command again.
                raise
            logger.info("Trying alternative installation approach...")
            env = dict(os.environ, PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu118")
            subprocess.check_call(pip_cmd, env=env)
        logger.info("Successfully installed flash-attention")
    except Exception as e:
        logger.warning(f"Failed to install flash-attention: {e}")
        logger.info("Continuing without flash-attention")

# Record whether flash-attn can be imported. Model loading decides separately
# which attention implementation is actually used.
# NOTE(review): the "will be used" message is optimistic. load_model_safely
# forces eager attention regardless of this flag.
try:
    import flash_attn
except ImportError:
    flash_attention_available = False
    logger.info("Flash Attention not available, will use standard attention mechanism")
else:
    flash_attention_available = True
    logger.info(f"Flash Attention will be used (version: {flash_attn.__version__})")

# Optional TensorBoard reporting: probe for the package once at import time.
try:
    import tensorboard
except ImportError:
    TENSORBOARD_AVAILABLE = False
    print("Tensorboard not available. Will skip tensorboard logging.")
else:
    TENSORBOARD_AVAILABLE = True

# Hub repository id (namespace/name) substituted when the short alias
# "phi4-cognitive-dataset" is requested (see load_and_prepare_dataset).
DEFAULT_DATASET = "George-API/phi4-cognitive-dataset"

def load_config(config_path):
    """Load the script's training configuration from a JSON file.

    Args:
        config_path: Path to the JSON config. Elsewhere in this script it is
            read for sections such as ``model_config``, ``training_config``,
            ``hardware_config``, ``lora_config`` and ``dataset_config``.

    Returns:
        The parsed JSON value (a dict for a well-formed config).

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    logger.info(f"Loading config from {config_path}")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _select_split(dataset):
    """Pick the split to train on from what ``load_dataset`` returned.

    A split mapping (``DatasetDict`` subclasses ``dict``) yields its 'train'
    split, or its first split with a warning when there is no 'train'.
    Anything else is returned unchanged.
    """
    if not isinstance(dataset, dict):
        return dataset
    if 'train' in dataset:
        return dataset['train']
    if dataset:
        split_name = next(iter(dataset))
        logger.warning(f"No 'train' split found; using split '{split_name}'")
        return dataset[split_name]
    return dataset


def _check_sort_order(dataset, sort_field):
    """Log first/middle/last values of *sort_field* and warn if they look unsorted."""
    first_value = dataset[0].get(sort_field, None)
    last_value = dataset[-1].get(sort_field, None)
    logger.info(f"Dataset sorted: first {sort_field}={first_value}, last {sort_field}={last_value}")

    sample_indices = [0, len(dataset) // 2, len(dataset) - 1]
    sample_values = [dataset[i].get(sort_field, None) for i in sample_indices]
    logger.info(f"Sample prompt numbers: {sample_values}")

    # Null values cannot be ordered against real values (TypeError), so only
    # the non-null samples are compared.
    comparable = [value for value in sample_values if value is not None]
    if not all(a <= b for a, b in zip(comparable, comparable[1:])):
        logger.warning("Dataset may not be properly sorted! Please check the ordering.")


def _detect_pre_tokenized(sample):
    """Log how *sample* stores its tokens; return True if it looks pre-tokenized."""
    input_ids = sample.get('input_ids')
    if isinstance(input_ids, list) and all(isinstance(x, int) for x in input_ids):
        logger.info("Dataset appears to be pre-tokenized with input_ids field")
        return True
    if 'conversations' not in sample:
        return False

    conversations = sample['conversations']
    logger.info(f"Sample conversations structure: {conversations[:1]}")
    if not isinstance(conversations, list) or len(conversations) == 0:
        return False
    conv = conversations[0]
    if not isinstance(conv, dict):
        return False
    if 'input_ids' in conv and isinstance(conv['input_ids'], list):
        logger.info("Dataset appears to be pre-tokenized in conversations.input_ids")
        return True
    if 'content' in conv:
        content = conv['content']
        if isinstance(content, list) and all(isinstance(x, int) for x in content):
            logger.info("Dataset appears to be pre-tokenized in conversations.content")
            return True
        logger.info("Dataset appears to contain string content that will need tokenization")
    return False


def load_and_prepare_dataset(dataset_name, config):
    """
    Load the fine-tuning dataset and sort it in ascending order.

    Args:
        dataset_name: Hub dataset id. The alias ``"phi4-cognitive-dataset"``
            is replaced by ``DEFAULT_DATASET``.
        config: Parsed config dict. ``config["dataset_config"]["sort_by_field"]``
            selects the sort column (default ``"prompt_number"``).

    Returns:
        A single dataset split (the 'train' split when available) sorted by
        the sort column. Its layout (pre-tokenized ids or string content) is
        logged but not changed.

    Raises:
        Any exception from loading or sorting, after logging hints about
        dataset access.
    """
    if dataset_name == "phi4-cognitive-dataset":
        dataset_name = DEFAULT_DATASET

    logger.info(f"Loading dataset: {dataset_name}")

    try:
        dataset = _select_split(load_dataset(dataset_name))

        dataset_config = config.get("dataset_config", {})
        sort_field = dataset_config.get("sort_by_field", "prompt_number")

        logger.info(f"Sorting dataset by {sort_field} in ascending order")
        dataset = dataset.sort(sort_field)

        if len(dataset) > 1:
            _check_sort_order(dataset, sort_field)

        logger.info(f"Dataset loaded with {len(dataset)} entries")
        logger.info(f"Dataset columns: {dataset.column_names}")

        if len(dataset) > 0:
            sample = dataset[0]
            logger.info(f"Sample entry structure: {list(sample.keys())}")
            if _detect_pre_tokenized(sample):
                logger.info("Using pre-tokenized dataset - tokenizer will only be used as fallback")
            else:
                logger.info("Dataset contains string content - tokenizer will be used")

        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        logger.error(f"Failed to load dataset: {dataset_name}")
        logger.error("Make sure the dataset exists and is accessible.")
        logger.error("If it's a private dataset, ensure your HF_TOKEN has access to it.")
        raise

def tokenize_string(text, tokenizer):
    """Return the token ids for *text*, encoded without special tokens.

    Falsy input (``None`` or ``""``) short-circuits to an empty list without
    calling the tokenizer.
    """
    if text:
        return tokenizer.encode(text, add_special_tokens=False)
    return []

# Data collator for pre-tokenized datasets (tokenizes string content only as a fallback)
class PreTokenizedCollator(DataCollatorMixin):
    """
    Data collator that can handle both pre-tokenized datasets and string content.

    For each feature, the token ids are taken from:

    * ``feature["input_ids"]``: used as-is when it is a list. A string is
      tokenized as a fallback, with a warning. Any other iterable is converted
      with ``list()``.
    * otherwise ``feature["conversations"]``, in one of the layouts handled by
      ``_ids_from_conversations``.

    Features that yield no ids are skipped with a log message. If every
    feature in a batch is skipped, a ``ValueError`` is raised. The input
    feature dicts are never modified.

    The returned batch holds right-padded ``torch.long`` tensors:

    * ``input_ids``: padded with ``pad_token_id``.
    * ``attention_mask``: 1 on real tokens.
    * ``labels``: the feature's ``labels`` if present, otherwise a copy of its
      input ids. Padding is -100 so the loss ignores it.
    """

    def __init__(self, pad_token_id=0, tokenizer=None):
        # Tokenizers without a pad token report pad_token_id=None, which would
        # break the tensor fill in __call__. Fall back to 0; padded positions
        # are masked by attention_mask and labels=-100 anyway.
        self.pad_token_id = 0 if pad_token_id is None else pad_token_id
        # Used only to tokenize string content as a fallback.
        self.tokenizer = tokenizer
        # Structure diagnostics are logged for the first batch only.
        self._logged_structure = False

    def _encode(self, text):
        """Tokenize *text* with the fallback tokenizer, without special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False)

    def _ids_from_conversations(self, conversations, log_structure):
        """Extract token ids from a ``conversations`` field; None if no layout matches."""
        if log_structure:
            logger.info(f"Conversations type: {type(conversations)}")
            if isinstance(conversations, list) and len(conversations) > 0:
                logger.info(f"First conversation type: {type(conversations[0])}")

        if not isinstance(conversations, list) or len(conversations) == 0:
            return None
        first = conversations[0]

        # Case 1: list of dicts whose first entry carries pre-tokenized 'input_ids'
        if isinstance(first, dict) and 'input_ids' in first:
            return first['input_ids']
        # Case 2: the conversations list itself is the token ids
        if all(isinstance(x, int) for x in conversations):
            return conversations
        # Case 3: list of dicts with a 'content' field (token ids or text)
        if isinstance(first, dict) and 'content' in first:
            content = first['content']
            if isinstance(content, list) and all(isinstance(x, int) for x in content):
                return content
            if isinstance(content, str) and self.tokenizer:
                logger.warning("Found string content in dataset. Tokenizing as fallback.")
                return self._encode(content)
            logger.warning(f"Unexpected content format: {type(content)}")
            return None
        # Case 4: list of strings -> join with spaces and tokenize
        if all(isinstance(x, str) for x in conversations) and self.tokenizer:
            logger.warning("Found string conversations in dataset. Tokenizing as fallback.")
            return self._encode(" ".join(conversations))
        return None

    def _extract_input_ids(self, feature, log_structure):
        """Return *feature*'s token ids as a list, or None if it must be skipped."""
        if 'input_ids' in feature:
            ids = feature['input_ids']
        elif 'conversations' in feature:
            ids = self._ids_from_conversations(feature['conversations'], log_structure)
        else:
            ids = None

        if ids is None:
            logger.warning("No input_ids found in this example. Skipping.")
            return None
        if isinstance(ids, list):
            return ids
        if isinstance(ids, str):
            if not self.tokenizer:
                # list() would split the string into characters, which cannot
                # be turned into a token tensor.
                logger.warning("Found string input_ids but no tokenizer is available. Skipping.")
                return None
            logger.warning("Found string input_ids in dataset. Tokenizing as fallback.")
            return self._encode(ids)
        try:
            return list(ids)
        except TypeError:
            logger.error(f"Could not convert input_ids to list: {type(ids)}")
            return None

    def __call__(self, features):
        log_structure = not self._logged_structure
        if log_structure and len(features) > 0:
            logger.info(f"Sample feature keys: {list(features[0].keys())}")

        # (ids, labels) pairs built without mutating the caller's feature dicts.
        examples = []
        for feature in features:
            ids = self._extract_input_ids(feature, log_structure and not examples)
            if ids is not None:
                examples.append((ids, feature.get("labels")))
        self._logged_structure = True

        if not examples:
            logger.error("No valid examples found in batch. Check dataset format.")
            raise ValueError("No valid examples found. Please check dataset structure.")

        batch_size = len(examples)
        batch_max_len = max(len(ids) for ids, _ in examples)
        input_ids = torch.full((batch_size, batch_max_len), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((batch_size, batch_max_len), dtype=torch.long)
        # -100 is ignored by the loss
        labels = torch.full((batch_size, batch_max_len), -100, dtype=torch.long)

        for row, (ids, row_labels) in enumerate(examples):
            ids = torch.as_tensor(ids, dtype=torch.long)
            seq_len = ids.shape[0]
            input_ids[row, :seq_len] = ids
            attention_mask[row, :seq_len] = 1
            if row_labels is None:
                labels[row, :seq_len] = ids
            else:
                row_labels = torch.as_tensor(row_labels, dtype=torch.long)
                labels[row, :row_labels.shape[0]] = row_labels

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

def create_training_marker(output_dir):
    """Write the marker files that flag an in-progress research training run.

    Two files are written:

    * ``TRAINING_ACTIVE`` in the current working directory (meant for app.py
      to discover), containing ``"Training active in <output_dir>"``.
    * ``RESEARCH_TRAINING_ONLY`` inside *output_dir*, which is created first
      if it does not exist.
    """
    with open("TRAINING_ACTIVE", "w") as marker:
        marker.write(f"Training active in {output_dir}")

    os.makedirs(output_dir, exist_ok=True)
    research_marker_path = os.path.join(output_dir, "RESEARCH_TRAINING_ONLY")
    notice = "This model is for research training only. No interactive outputs."
    with open(research_marker_path, "w") as marker:
        marker.write(notice)

def remove_training_marker():
    """Delete the ``TRAINING_ACTIVE`` marker from the current working directory.

    A missing marker is not an error; the call is then a no-op. The removal is
    attempted directly rather than after an ``os.path.exists`` check. That way
    a marker deleted concurrently between the check and the delete cannot
    raise ``FileNotFoundError``. Other ``OSError``s (e.g. permissions) still
    propagate.
    """
    try:
        os.remove("TRAINING_ACTIVE")
    except FileNotFoundError:
        return
    logger.info("Removed training active marker")

def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention=False, use_deepspeed=False):
    """
    Load a bitsandbytes 4-bit (NF4) causal LM and its tokenizer with plain
    HuggingFace ``from_pretrained``. Unsloth is bypassed and eager attention
    is forced, to avoid memory-efficient-attention issues.

    Args:
        model_name: Hub id or local path of the model.
        max_seq_length: Accepted for interface compatibility; not used here.
        dtype: torch dtype for non-quantized weights and the 4-bit compute
            dtype. Defaults to ``torch.float16``.
        use_flash_attention: Accepted for interface compatibility; ignored,
            since eager attention is always used.
        use_deepspeed: When True, the model is placed on this process's GPU
            (index from ``LOCAL_RANK``). Otherwise ``device_map="auto"`` is
            used.

    Returns:
        ``(model, tokenizer)``. If the tokenizer defines no pad token, its EOS
        token is used as the pad token.
    """
    logger.info(f"Loading model: {model_name}")

    from transformers import BitsAndBytesConfig

    compute_dtype = dtype or torch.float16
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    # Force eager implementation to avoid BMGHK format issues
    attn_implementation = "eager"
    logger.info(f"Forcing eager attention implementation to avoid BMGHK format issues")
    logger.info("Bypassing Unsloth optimizations to avoid memory-efficient attention issues")

    gpu_count = torch.cuda.device_count()
    logger.info(f"Found {gpu_count} GPU(s) available")

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.attn_implementation = attn_implementation
    # Disable any custom attention switches a remote-code config may expose.
    if hasattr(config, "use_flash_attention"):
        config.use_flash_attention = False
    if hasattr(config, "use_memory_efficient_attention"):
        config.use_memory_efficient_attention = False

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batch padding reads tokenizer.pad_token_id. Reuse EOS, matching what
        # train() sets on the tokenizer it loads itself.
        tokenizer.pad_token = tokenizer.eos_token

    if use_deepspeed:
        # Quantized bitsandbytes modules must be placed on a GPU. transformers
        # rejects CPU placement of 4-bit modules unless fp32 CPU offload is
        # enabled. So load straight onto this process's GPU instead of "cpu".
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        logger.info(f"Using DeepSpeed - loading quantized model on cuda:{local_rank}")
        device_map = {"": local_rank}
    else:
        device_map = "auto"
    logger.info(f"Using device_map={device_map} for initial model loading")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        device_map=device_map,
        torch_dtype=compute_dtype,
        quantization_config=bnb_config,
        trust_remote_code=True,
        attn_implementation=attn_implementation,
    )

    logger.info("Model loaded successfully with standard HF loading")
    return model, tokenizer

def train(config_path, dataset_name, output_dir):
    """Main training function - RESEARCH TRAINING PHASE ONLY"""
    # Load environment variables
    load_dotenv()
    config = load_config(config_path)
    
    # Set CUDA launch blocking for better error reporting
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    
    # Try to unload xformers if it's loaded
    if 'xformers' in sys.modules:
        logger.info("Removing xformers from sys.modules")
        del sys.modules['xformers']
    
    # Patch torch.nn.functional to avoid memory_efficient_attention
    try:
        import torch.nn.functional as F
        if hasattr(F, 'scaled_dot_product_attention'):
            logger.info("Patching torch.nn.functional.scaled_dot_product_attention")
            original_sdpa = F.scaled_dot_product_attention
            
            def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
                # Force disable memory efficient attention
                logger.info("Using safe scaled_dot_product_attention (no xformers)")
                return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
            
            F.scaled_dot_product_attention = safe_sdpa
    except Exception as e:
        logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
    
    # Extract configs
    model_config = config.get("model_config", {})
    training_config = config.get("training_config", {})
    hardware_config = config.get("hardware_config", {})
    lora_config = config.get("lora_config", {})
    dataset_config = config.get("dataset_config", {})
    
    # Set the output directory
    output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
    os.makedirs(output_dir, exist_ok=True)
    
    # Create training marker
    create_training_marker(output_dir)
    
    try:
        # Print configuration summary
        logger.info("RESEARCH TRAINING PHASE ACTIVE - No output generation")
        logger.info("Configuration Summary:")
        model_name = model_config.get("model_name_or_path")
        logger.info(f"Model: {model_name}")
        logger.info(f"Dataset: {dataset_name if dataset_name != 'phi4-cognitive-dataset' else DEFAULT_DATASET}")
        logger.info(f"Output directory: {output_dir}")
        logger.info("IMPORTANT: Using already 4-bit quantized model - not re-quantizing")
        
        # Check GPU availability
        gpu_count = torch.cuda.device_count()
        logger.info(f"Found {gpu_count} GPU(s) available")
        for i in range(gpu_count):
            logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        
        # Load and prepare the dataset
        dataset = load_and_prepare_dataset(dataset_name, config)
        
        # Initialize tokenizer (just for model initialization, not for tokenizing data)
        logger.info("Loading tokenizer (for model initialization only, not for tokenizing data)")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        tokenizer.pad_token = tokenizer.eos_token
        
        # Initialize model
        logger.info("Initializing model (preserving 4-bit quantization)")
        
        # Use full sequence length of 2048 as required for pre-tokenized dataset
        max_seq_length = training_config.get("max_seq_length", 2048)
        logger.info(f"Using sequence length: {max_seq_length} as required for pre-tokenized dataset")
        
        # Create LoRA config directly
        logger.info("Creating LoRA configuration")
        lora_config_obj = LoraConfig(
            r=lora_config.get("r", 16),
            lora_alpha=lora_config.get("lora_alpha", 32),
            lora_dropout=lora_config.get("lora_dropout", 0.05),
            bias=lora_config.get("bias", "none"),
            target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
        )
        
        # Force eager attention implementation
        use_flash_attention = False  # Override to force eager implementation
        
        # Initialize ds_config_path to None before checking
        ds_config_path = None
        
        # Optimize batch size for multi-GPU setup
        # For 4x L4 GPUs (24GB each), we can safely use a larger batch size
        per_device_train_batch_size = 4 if gpu_count >= 4 else 2
        logger.info(f"Using batch size: {per_device_train_batch_size} per device (effective batch size: {per_device_train_batch_size * gpu_count * training_config.get('gradient_accumulation_steps', 4)})")
        
        # Check if DeepSpeed config is available and if MPI is disabled
        deepspeed_config = config.get("deepspeed_config", None)
        if deepspeed_config and os.environ.get("DISABLE_DEEPSPEED_MPI", "0") != "1":
            logger.info("DeepSpeed configuration found - enabling DeepSpeed for distributed training")
            
            # Create a temporary DeepSpeed config file
            ds_config_path = os.path.join(output_dir, "ds_config_temp.json")
            
            # Update DeepSpeed config with dynamic values
            if isinstance(deepspeed_config.get("train_micro_batch_size_per_gpu"), str) and deepspeed_config.get("train_micro_batch_size_per_gpu") == "auto":
                deepspeed_config["train_micro_batch_size_per_gpu"] = per_device_train_batch_size
                
            if isinstance(deepspeed_config.get("train_batch_size"), str) and deepspeed_config.get("train_batch_size") == "auto":
                deepspeed_config["train_batch_size"] = per_device_train_batch_size * gpu_count
            
            # Ensure communication backend is set to avoid MPI
            if "communication_data_type" not in deepspeed_config:
                deepspeed_config["communication_data_type"] = "fp16"
            
            # Write the DeepSpeed config to a file
            with open(ds_config_path, 'w') as f:
                json.dump(deepspeed_config, f, indent=2)
            
            logger.info(f"Created DeepSpeed config at {ds_config_path}")
            logger.info(f"DeepSpeed ZeRO Stage: {deepspeed_config.get('zero_optimization', {}).get('stage', 'Not specified')}")
            
            # Enable CPU offloading if configured
            if deepspeed_config.get("zero_optimization", {}).get("offload_optimizer", {}).get("device") == "cpu":
                logger.info("DeepSpeed CPU offloading enabled for optimizer states")
            
            # Set using_deepspeed flag
            using_deepspeed = True
        elif os.environ.get("DISABLE_DEEPSPEED_MPI", "0") == "1":
            logger.warning("DeepSpeed MPI support is disabled due to missing mpi4py. Continuing without DeepSpeed.")
            ds_config_path = None
            using_deepspeed = False
        else:
            logger.warning("No DeepSpeed configuration found - continuing without DeepSpeed")
            ds_config_path = None
            using_deepspeed = False
        
        # Initialize model with our safe loading function
        logger.info("Loading pre-quantized model with eager attention")
        dtype = torch.float16 if hardware_config.get("fp16", True) else None
        model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention, use_deepspeed=using_deepspeed)
        
        # Disable generation capabilities for research training
        logger.info("Disabling generation capabilities - Research training only")
        model.config.is_decoder = False
        model.config.task_specific_params = None
        
        # Apply LoRA to model
        logger.info("Applying LoRA to model")
        from peft import get_peft_model
        model = get_peft_model(model, lora_config_obj)
        logger.info("Successfully applied LoRA with standard PEFT")
        
        # Explicitly set attention implementation in model config again after PEFT
        model.config.attn_implementation = "eager"
        
        # No need to format the dataset - it's already pre-tokenized
        logger.info("Using dataset with flexible tokenization handling")
        logger.info("Will use pre-tokenized data if available, or tokenize strings as fallback")
        training_dataset = dataset
        
        # Configure reporting backends with fallbacks
        reports = []
        if TENSORBOARD_AVAILABLE:
            reports.append("tensorboard")
            logger.info("Tensorboard available and enabled for reporting")
        else:
            logger.warning("Tensorboard not available - metrics won't be logged to tensorboard")
            
        if os.getenv("WANDB_API_KEY"):
            reports.append("wandb")
            logger.info("Wandb API key found, enabling wandb reporting")
        
        # Default to "none" if no reporting backends are available
        if not reports:
            reports = ["none"]
            logger.warning("No reporting backends available - training metrics won't be logged")
        
        training_args_dict = {
            "output_dir": output_dir,
            "num_train_epochs": training_config.get("num_train_epochs", 3),
            "per_device_train_batch_size": per_device_train_batch_size,
            "gradient_accumulation_steps": training_config.get("gradient_accumulation_steps", 4),
            "learning_rate": training_config.get("learning_rate", 2e-5),
            "lr_scheduler_type": training_config.get("lr_scheduler_type", "cosine"),
            "warmup_ratio": training_config.get("warmup_ratio", 0.03),
            "weight_decay": training_config.get("weight_decay", 0.01),
            "optim": training_config.get("optim", "adamw_torch"),
            "logging_steps": training_config.get("logging_steps", 10),
            "save_steps": training_config.get("save_steps", 200),
            "save_total_limit": training_config.get("save_total_limit", 3),
            "fp16": hardware_config.get("fp16", True),
            "bf16": hardware_config.get("bf16", False),
            "max_grad_norm": training_config.get("max_grad_norm", 0.3),
            "report_to": reports,
            "logging_first_step": training_config.get("logging_first_step", True),
            "disable_tqdm": training_config.get("disable_tqdm", False),
            "remove_unused_columns": False,
            "seed": 42,
            "dataloader_num_workers": 4,  # Use multiple workers for data loading
        }
        
        # Add DeepSpeed config path if available and enabled
        if using_deepspeed and ds_config_path:
            logger.info("Adding DeepSpeed configuration to training arguments")
            training_args_dict["deepspeed"] = ds_config_path
        else:
            logger.info("DeepSpeed is disabled - using standard distributed training")
        
        # Create TrainingArguments with validated parameters
        try:
            training_args = TrainingArguments(**training_args_dict)
        except Exception as e:
            logger.error(f"Failed to create training arguments with DeepSpeed: {e}")
            if "deepspeed" in training_args_dict:
                logger.warning("Removing DeepSpeed configuration and trying again")
                del training_args_dict["deepspeed"]
                training_args = TrainingArguments(**training_args_dict)
                using_deepspeed = False
        
        # Create trainer with pre-tokenized collator
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=training_dataset,
            data_collator=PreTokenizedCollator(pad_token_id=tokenizer.pad_token_id, tokenizer=tokenizer),
        )
        
        # Start training
        logger.info("Starting training - RESEARCH PHASE ONLY")
        trainer.train()
        
        # Save the model
        logger.info(f"Saving model to {output_dir}")
        trainer.save_model(output_dir)
        
        # Save LoRA adapter separately for easier deployment
        lora_output_dir = os.path.join(output_dir, "lora_adapter")
        model.save_pretrained(lora_output_dir)
        logger.info(f"Saved LoRA adapter to {lora_output_dir}")
        
        # Save tokenizer for completeness
        tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
        tokenizer.save_pretrained(tokenizer_output_dir)
        logger.info(f"Saved tokenizer to {tokenizer_output_dir}")
        
        # Copy config file for reference
        with open(os.path.join(output_dir, "training_config.json"), "w") as f:
            json.dump(config, f, indent=2)
        
        logger.info("Training complete - RESEARCH PHASE ONLY")
        return output_dir
    
    finally:
        # Always remove the training marker when done
        remove_training_marker()

if __name__ == "__main__":
    # Command-line entry point: parse arguments, run train(), and report where
    # the fine-tuned output was written.
    parser = argparse.ArgumentParser(description="Fine-tune Unsloth/DeepSeek-R1-Distill-Qwen-14B-4bit model (RESEARCH ONLY)")
    parser.add_argument("--config", type=str, default="transformers_config.json", 
                        help="Path to the transformers config JSON file")
    parser.add_argument("--dataset", type=str, default="phi4-cognitive-dataset", 
                        help="Dataset name or path")
    parser.add_argument("--output_dir", type=str, default=None, 
                        help="Output directory for the fine-tuned model")
    parser.add_argument("--use_flash_attention", action="store_true",
                        help="Use Flash Attention if available (NOT RECOMMENDED)")
    
    args = parser.parse_args()
    
    # Flash attention is disabled for this script (TRANSFORMERS_NO_FLASH_ATTENTION
    # is set at import time and the model config is forced to "eager"), so the
    # flag is accepted only for CLI compatibility. Tell the user instead of
    # silently discarding their request, then force eager implementation.
    if args.use_flash_attention:
        logger.warning(
            "--use_flash_attention was requested but is ignored; "
            "this script always uses eager attention"
        )
    args.use_flash_attention = False
    
    # Run training - Research phase only
    try:
        output_path = train(args.config, args.dataset, args.output_dir)
        print(f"Research training completed. Model saved to: {output_path}")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        # Clean up marker if training fails (train() also does this in its own
        # finally; this call is defensive). Cleanup is best-effort: if it
        # raises, log that and still re-raise the ORIGINAL training error
        # instead of letting the cleanup error replace it.
        try:
            remove_training_marker()
        except Exception:
            logger.warning("Failed to remove training marker after training failure", exc_info=True)
        raise