import json
import random
import re
from typing import Dict, Any, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class QwenTextGenerator:
    """Text generation using Qwen2.5-0.5B-Instruct for monster traits and dialogue.

    The tokenizer and model are loaded lazily on first use (``load_model``).
    ``generate_traits`` never raises: if loading or generation fails it
    returns randomly assembled fallback traits instead.
    """

    def __init__(self, device: str = "cuda"):
        """Configure the generator without loading any weights.

        Args:
            device: Preferred torch device string. Replaced by ``"cpu"``
                whenever ``torch.cuda.is_available()`` is False.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.model_id = "Qwen/Qwen2.5-0.5B-Instruct"

        # Generation parameters
        self.max_new_tokens = 150
        self.temperature = 0.8
        self.top_p = 0.9

        # Monster trait vocabularies. They are used both to recognise traits
        # in model output (the first entry found, in list order, wins) and to
        # draw random values when nothing is recognised.
        self.trait_categories = {
            'elements': ['fire', 'water', 'earth', 'wind', 'electric', 'ice', 'nature', 'dark', 'light', 'neutral'],
            'personalities': ['brave', 'timid', 'aggressive', 'gentle', 'playful', 'serious', 'loyal', 'independent', 'curious', 'protective'],
            'body_types': ['bipedal', 'quadruped', 'serpentine', 'avian', 'aquatic', 'insectoid', 'humanoid', 'amorphous'],
            'sizes': ['tiny', 'small', 'medium', 'large', 'giant'],
            'special_features': ['wings', 'horns', 'tail', 'spikes', 'fur', 'scales', 'armor', 'crystals', 'flames', 'aura']
        }

    def load_model(self):
        """Lazily load the tokenizer and model; a no-op once the model is loaded.

        CUDA devices load float16 weights, every other device float32.
        Plain ``"cuda"`` lets ``device_map="auto"`` place the weights; any
        other device string (``"cpu"``, ``"cuda:1"``, ...) gets an explicit
        ``.to(device)``.

        Raises:
            Exception: Any loading error is printed and re-raised.
        """
        if self.model is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

                use_device_map = self.device == "cuda"
                torch_dtype = torch.float16 if self.device.startswith("cuda") else torch.float32

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    device_map="auto" if use_device_map else None,
                    low_cpu_mem_usage=True,
                )

                # Without a device_map the weights start on CPU, so move them
                # explicitly (this also covers indexed devices like "cuda:1").
                if not use_device_map:
                    self.model.to(self.device)

            except Exception as e:
                print(f"Failed to load text generation model: {e}")
                raise

    def generate_traits(self, description: str) -> Dict[str, Any]:
        """Generate monster traits from a free-text description.

        Args:
            description: User-supplied monster description.

        Returns:
            Traits parsed from the model output (see ``_parse_traits``), or
            the result of ``_generate_fallback_traits`` if anything fails.
        """
        try:
            self.load_model()

            prompt = self._create_trait_prompt(description)

            # Send inputs to wherever the weights actually live; with
            # device_map="auto" that is e.g. "cuda:0" rather than "cuda".
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            prompt_length = inputs['input_ids'].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

            return self._parse_traits(response, description)

        except Exception as e:
            # Deliberate best effort: callers always get a traits dict.
            print(f"Error generating traits: {e}")
            return self._generate_fallback_traits(description)

    def generate_dialogue(self, traits: Dict[str, Any]) -> str:
        """Generate monster dialogue: two personality emojis plus two status numbers.

        Args:
            traits: Trait dictionary; only ``'personality'`` is read.

        Returns:
            Two distinct emojis chosen for the personality, followed by a
            random HP percentage (70-100) and happiness value (60-95), each
            suffixed with a keycap emoji modifier. Falls back to a fixed
            string if anything goes wrong.
        """
        try:
            personality = traits.get('personality', 'neutral')

            # Emoji mapping for personalities
            emoji_map = {
                'brave': ['💪', '🔥', '⚔️', '🛡️'],
                'timid': ['😰', '🥺', '💦', '❓'],
                'aggressive': ['😤', '💢', '🔥', '⚡'],
                'gentle': ['💚', '🌸', '✨', '🌟'],
                'playful': ['😊', '🎮', '🎯', '🎪'],
                'serious': ['🤖', '📊', '⚡', '💯'],
                'loyal': ['💖', '🤝', '🛡️', '⭐'],
                'independent': ['🚀', '🌍', '🔮', '💫'],
                'curious': ['🔍', '❓', '💡', '🌟'],
                'protective': ['🛡️', '💪', '🏰', '⚔️']
            }

            emojis = emoji_map.get(personality, ['🤖', '💚', '✨'])
            selected_emojis = random.sample(emojis, min(2, len(emojis)))

            # Status numbers representing the monster's current state.
            hp_percent = random.randint(70, 100)
            happiness = random.randint(60, 95)

            dialogue = f"{selected_emojis[0]}{selected_emojis[1] if len(selected_emojis) > 1 else '💚'}"
            dialogue += f"{hp_percent}️⃣{happiness}️⃣"

            return dialogue

        except Exception as e:
            print(f"Error generating dialogue: {e}")
            return "🤖💚9️⃣0️⃣"

    def _create_trait_prompt(self, description: str) -> str:
        """Build a ChatML-formatted prompt asking the model for monster traits."""
        prompt = f"""<|im_start|>system
You are a creative game designer creating unique digital monsters. Generate detailed traits for a monster based on the description.
<|im_end|>
<|im_start|>user
Create traits for this monster: {description}

Include: name, species, element, personality, appearance details, and special abilities.
<|im_end|>
<|im_start|>assistant
"""
        return prompt

    @staticmethod
    def _find_keyword(text: str, keyword: str) -> int:
        """Return the index where *keyword* first begins a word in *text*, or -1.

        Matching is case-insensitive. Anchoring at a word start avoids false
        hits such as "ice" in "nice" or "timid" in "intimidating", while still
        matching inflected forms such as "fireball" or "powers".
        """
        match = re.search(r"\b" + re.escape(keyword), text, re.IGNORECASE)
        return match.start() if match else -1

    @staticmethod
    def _extract_label(text: str, label: str) -> str:
        """Return the rest of the line following ``label:`` in *text*, or "".

        Matching is case-insensitive and tolerates markdown emphasis around
        the label (``**Name:** X`` or ``**Name**: X``). Surrounding emphasis
        and quote characters are stripped from the value, which never spills
        onto the following line.
        """
        match = re.search(
            r"\b" + re.escape(label) + r"[*_ \t]*:[ \t]*(.*)",
            text,
            re.IGNORECASE,
        )
        if match is None:
            return ""
        return match.group(1).strip().strip("*_\"'`").strip()

    def _match_category(self, text: str, category: str) -> str:
        """Return the first ``trait_categories[category]`` entry found in *text*.

        Entries are tried in list order and must start a word. If none is
        found, a random entry of the category is returned instead.
        """
        options = self.trait_categories[category]
        for option in options:
            if self._find_keyword(text, option) != -1:
                return option
        return random.choice(options)

    def _parse_traits(self, response: str, original_description: str) -> Dict[str, Any]:
        """Turn free-form model output into a trait dictionary.

        Args:
            response: Generated text (prompt excluded).
            original_description: User description, stored verbatim.

        Returns:
            Dict with keys ``description``, ``raw_response``, ``name``,
            ``species``, ``element``, ``personality``, ``appearance``,
            ``abilities`` and ``color_scheme``. Values not found in the
            response are filled in randomly (or with a default species).
        """
        traits = {
            'description': original_description,
            'raw_response': response
        }

        # An empty label value (e.g. "Name:" followed by a newline) counts as
        # missing instead of producing a blank name.
        traits['name'] = self._extract_label(response, 'name') or self._generate_name()
        # Same default as _generate_fallback_traits, so both paths share a schema.
        traits['species'] = self._extract_label(response, 'species') or 'Digital Monster'

        traits['element'] = self._match_category(response, 'elements')
        traits['personality'] = self._match_category(response, 'personalities')

        traits['appearance'] = self._extract_appearance(response)
        traits['abilities'] = self._extract_abilities(response, traits['element'])
        traits['color_scheme'] = self._get_color_scheme(traits['element'])

        return traits

    def _generate_name(self) -> str:
        """Generate a random monster name from a prefix and a suffix."""
        prefixes = ['Pyro', 'Aqua', 'Terra', 'Aero', 'Volt', 'Cryo', 'Flora', 'Shadow', 'Lumi', 'Neo']
        suffixes = ['mon', 'beast', 'guard', 'wing', 'claw', 'fang', 'horn', 'tail', 'byte', 'spark']

        return random.choice(prefixes) + random.choice(suffixes)

    def _extract_appearance(self, response: str) -> str:
        """Extract an appearance snippet from *response*, or synthesise one.

        The snippet starts at the first keyword found (keywords tried in a
        fixed priority order, each at a word start). It runs to the next '.',
        else the next newline, else the end of the response.
        """
        appearance_keywords = ['appearance', 'looks like', 'resembles', 'body', 'color', 'size']

        for keyword in appearance_keywords:
            start = self._find_keyword(response, keyword)
            if start == -1:
                continue
            end = response.find('.', start)
            if end == -1:
                end = response.find('\n', start)
            if end == -1:
                end = len(response)
            return response[start:end].strip()

        # Fallback appearance
        body_type = random.choice(self.trait_categories['body_types'])
        size = random.choice(self.trait_categories['sizes'])
        feature = random.choice(self.trait_categories['special_features'])

        return f"A {size} {body_type} creature with {feature}"

    def _extract_abilities(self, response: str, element: str) -> List[str]:
        """Extract ability phrases from *response*, or pick element defaults.

        For each keyword, the text from its first word-start occurrence up to
        the next '.' is taken. A phrase is skipped when no '.' follows it, or
        when it overlaps a phrase already taken (e.g. "can" inside a "power
        ..." sentence). If no phrases are found, two distinct abilities are
        sampled for *element*; unknown elements use the neutral set.
        """
        abilities = []
        taken_spans = []

        ability_keywords = ['ability', 'power', 'skill', 'can', 'capable']
        for keyword in ability_keywords:
            start = self._find_keyword(response, keyword)
            if start == -1:
                continue
            end = response.find('.', start)
            if end == -1:
                continue
            if any(start < span_end and span_start < end for span_start, span_end in taken_spans):
                continue
            taken_spans.append((start, end))
            abilities.append(response[start:end].strip())

        if not abilities:
            element_abilities = {
                'fire': ['Flame Burst', 'Heat Wave', 'Ember Shield'],
                'water': ['Aqua Jet', 'Bubble Shield', 'Tidal Wave'],
                'earth': ['Rock Throw', 'Earthquake', 'Stone Armor'],
                'wind': ['Gust', 'Tornado', 'Wind Shield'],
                'electric': ['Thunder Shock', 'Static Field', 'Lightning Speed'],
                'ice': ['Ice Beam', 'Frost Armor', 'Blizzard'],
                'nature': ['Vine Whip', 'Healing Bloom', 'Nature\'s Guard'],
                'dark': ['Shadow Strike', 'Dark Pulse', 'Void Shield'],
                'light': ['Light Beam', 'Healing Light', 'Radiant Shield'],
                'neutral': ['Tackle', 'Defense Curl', 'Focus']
            }

            abilities = random.sample(
                element_abilities.get(element, element_abilities['neutral']),
                2
            )

        return abilities

    def _get_color_scheme(self, element: str) -> str:
        """Return a colour-scheme description for *element* (generic if unknown)."""
        color_schemes = {
            'fire': 'red and orange with yellow accents',
            'water': 'blue and cyan with white highlights',
            'earth': 'brown and green with stone textures',
            'wind': 'white and light blue with swirling patterns',
            'electric': 'yellow and blue with sparking effects',
            'ice': 'light blue and white with crystalline features',
            'nature': 'green and brown with leaf patterns',
            'dark': 'black and purple with shadow effects',
            'light': 'white and gold with glowing aura',
            'neutral': 'gray and silver with balanced tones'
        }

        return color_schemes.get(element, 'varied colors with unique patterns')

    def _generate_fallback_traits(self, description: str) -> Dict[str, Any]:
        """Build fully random traits, used when model loading or generation fails."""
        element = random.choice(self.trait_categories['elements'])
        personality = random.choice(self.trait_categories['personalities'])

        return {
            'name': self._generate_name(),
            'species': 'Digital Monster',
            'element': element,
            'personality': personality,
            'appearance': f"A unique {random.choice(self.trait_categories['sizes'])} digital creature",
            'color_scheme': self._get_color_scheme(element),
            'abilities': self._extract_abilities("", element),
            'description': description
        }

    def to(self, device: str):
        """Record *device* and move the loaded model there, if one is loaded."""
        self.device = device
        # NOTE(review): a model dispatched with device_map="auto" may not
        # support a plain .to(); confirm against the transformers version used.
        if self.model is not None:
            self.model.to(device)

    def __del__(self):
        """Drop model/tokenizer references and release cached CUDA memory."""
        try:
            # getattr: __init__ may not have completed before destruction.
            if getattr(self, "model", None) is not None:
                del self.model
            if getattr(self, "tokenizer", None) is not None:
                del self.tokenizer
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Finalizers must not raise; at interpreter shutdown module
            # globals such as `torch` may already have been torn down.
            pass