import os
import threading
import logging

import torch
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CodeDebuggerWrapper:
    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.model_type = None
        self._ensure_model()

    def _ensure_model(self):
        """Load model and tokenizer with fallback strategies."""
        logger.info(f"Loading model {self.model_name} ...")
        try:
            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            logger.info("✅ Tokenizer loaded successfully")

            # Add pad token if it doesn't exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Check model configuration
            config = AutoConfig.from_pretrained(self.model_name)
            logger.info(f"Model config type: {type(config).__name__}")

            # Try different model loading strategies
            loading_strategies = [
                ("AutoModel with trust_remote_code",
                 lambda: AutoModel.from_pretrained(self.model_name, trust_remote_code=True)),
                ("AutoModelForCausalLM with trust_remote_code",
                 lambda: AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True)),
                ("AutoModelForSeq2SeqLM with trust_remote_code",
                 lambda: AutoModelForSeq2SeqLM.from_pretrained(self.model_name, trust_remote_code=True)),
                ("AutoModel without trust_remote_code",
                 lambda: AutoModel.from_pretrained(self.model_name)),
                ("AutoModelForCausalLM without trust_remote_code",
                 lambda: AutoModelForCausalLM.from_pretrained(self.model_name)),
            ]

            for strategy_name, strategy_func in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy_name}")
                    self.model = strategy_func()
                    self.model_type = type(self.model).__name__
                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
                    logger.info(f"Model type: {self.model_type}")
                    break
                except Exception as e:
                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
                    continue

            if self.model is None:
                raise RuntimeError("❌ Failed to load model with any strategy")

            # Set to evaluation mode
            self.model.eval()
            logger.info("✅ Model set to evaluation mode")

        except Exception as e:
            logger.error(f"❌ Critical error in model loading: {e}")
            raise

    def debug(self, code: str) -> str:
        """Debug the provided code."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."

        try:
            # Prepare input with a clear prompt
            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True,
            )
            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")

            # Generate response based on model type
            with torch.no_grad():
                if hasattr(self.model, 'generate'):
                    logger.info("Using generate method")

                    # Generation parameters
                    generation_kwargs = {
                        **inputs,
                        'max_new_tokens': 256,
                        'num_beams': 3,
                        'early_stopping': True,
                        'do_sample': True,
                        'temperature': 0.7,
                        'pad_token_id': self.tokenizer.pad_token_id,
                    }

                    # Add eos_token_id if available
                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id

                    try:
                        outputs = self.model.generate(**generation_kwargs)
                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Clean up the response
                        if prompt in response:
                            response = response.replace(prompt, "").strip()

                        # If the response is empty or too short, provide a fallback
                        if not response or len(response.strip()) < 10:
                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."
                    except Exception as gen_error:
                        logger.error(f"Generation error: {gen_error}")
                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
                else:
                    # For encoder-only models without a generate() method
                    logger.info("Model doesn't support generation. Using encoder-only approach.")
                    outputs = self.model(**inputs)
                    # This is a simplified approach - you might need to add a classification head
                    # or use a different strategy based on your specific model
                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."

            return response

        except Exception as e:
            logger.error(f"Error in debug method: {e}")
            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"
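

# --- Illustrative usage (a minimal sketch, not part of the original module) ---
# The block below shows one way the wrapper might be exercised from the command
# line: instantiate it, hand it a buggy snippet, and print whatever string
# debug() returns. The __main__ guard and the sample snippet are assumptions
# added for demonstration; adapt them to your actual entry point (for example,
# a Gradio or FastAPI app).
if __name__ == "__main__":
    # Loads the tokenizer and tries each model-loading strategy in turn
    debugger = CodeDebuggerWrapper()

    # Hypothetical buggy input with a deliberate indentation error
    buggy_code = (
        "def add(a, b):\n"
        "return a + b\n"
    )

    # Returns either the model's suggested fix or a diagnostic message
    # (e.g. when the loaded model does not support text generation).
    print(debugger.debug(buggy_code))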