import os
import threading

from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig
)
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class CodeDebuggerWrapper:
    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.model_type = None
        self._ensure_model()
    
    def _ensure_model(self):
        """Load model and tokenizer with fallback strategies."""
        logger.info(f"Loading model {self.model_name} ...")
        
        try:
            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            logger.info("✅ Tokenizer loaded successfully")
            
            # Reuse the EOS token as the pad token if none is defined
            if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Check model configuration
            config = AutoConfig.from_pretrained(self.model_name)
            logger.info(f"Model config type: {type(config).__name__}")
            
            # Try different model loading strategies
            loading_strategies = [
                ("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name
                )),
                ("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name
                )),
            ]
            
            for strategy_name, strategy_func in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy_name}")
                    self.model = strategy_func()
                    self.model_type = type(self.model).__name__
                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
                    logger.info(f"Model type: {self.model_type}")
                    break
                except Exception as e:
                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
                    continue
            
            if self.model is None:
                raise RuntimeError("❌ Failed to load model with any strategy")
            
            # Set to evaluation mode
            self.model.eval()
            logger.info("✅ Model set to evaluation mode")
            
        except Exception as e:
            logger.error(f"❌ Critical error in model loading: {e}")
            raise
    
    def debug(self, code: str) -> str:
        """Debug the provided code."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."
        
        try:
            # Prepare input with a clear prompt
            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"
            
            # Tokenize input
            inputs = self.tokenizer(
                prompt, 
                return_tensors="pt", 
                max_length=512, 
                truncation=True,
                padding=True
            )
            
            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")
            
            # Generate response based on model type
            with torch.no_grad():
                if hasattr(self.model, 'generate'):
                    logger.info("Using generate method")
                    
                    # Generation parameters
                    generation_kwargs = {
                        **inputs,
                        'max_new_tokens': 256,
                        'num_beams': 3,
                        'early_stopping': True,
                        'do_sample': True,
                        'temperature': 0.7,
                        'pad_token_id': self.tokenizer.pad_token_id,
                    }
                    
                    # Add eos_token_id if available
                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
                    
                    try:
                        outputs = self.model.generate(**generation_kwargs)
                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                        
                        # Clean up the response
                        if prompt in response:
                            response = response.replace(prompt, "").strip()
                        
                        # If response is empty or too short, provide a fallback
                        if not response or len(response.strip()) < 10:
                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."
                        
                    except Exception as gen_error:
                        logger.error(f"Generation error: {gen_error}")
                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
                
                else:
                    # For encoder-only models
                    logger.info("Model doesn't support generation. Using encoder-only approach.")
                    outputs = self.model(**inputs)
                    
                    # This is a simplified approach - you might need to add a classification head
                    # or use a different strategy based on your specific model
                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."
            
            return response
            
        except Exception as e:
            logger.error(f"Error in debug method: {e}")
            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"
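

# Minimal usage sketch: exercises the wrapper above on an intentionally broken
# snippet. Running it downloads the "Girinath11/aiml_code_debug_model"
# checkpoint, so it needs network access and disk space; the sample code
# below is purely illustrative and can be swapped for any string of code.
if __name__ == "__main__":
    debugger = CodeDebuggerWrapper()

    # Deliberately broken: missing colon on the def line and an unclosed call.
    buggy_code = (
        "def add(a, b)\n"
        "    return a + b\n"
        "\n"
        "print(add(1, 2)\n"
    )

    print(debugger.debug(buggy_code))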