Girinath11 committed
Commit 4707188 · verified · 1 Parent(s): 835c4ff

Update model_wrapper.py

Files changed (1):
  1. model_wrapper.py +302 -84
model_wrapper.py CHANGED
@@ -1,17 +1,19 @@
 import os
 import threading
-
+# model_wrapper.py - Enhanced version with better debugging
 from transformers import (
     AutoTokenizer,
     AutoModel,
     AutoModelForSeq2SeqLM,
     AutoModelForCausalLM,
-    AutoConfig
+    AutoConfig,
+    pipeline
 )
 import torch
 import logging
+import sys
+import traceback

-# Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

@@ -21,41 +23,115 @@ class CodeDebuggerWrapper:
         self.model = None
         self.tokenizer = None
         self.model_type = None
+        self.pipeline = None
         self._ensure_model()

+    def _log_system_info(self):
+        """Log system information for debugging."""
+        logger.info(f"Python version: {sys.version}")
+        logger.info(f"PyTorch version: {torch.__version__}")
+        try:
+            import transformers
+            logger.info(f"Transformers version: {transformers.__version__}")
+        except Exception:
+            logger.warning("Could not get transformers version")
+
     def _ensure_model(self):
-        """Load model and tokenizer with fallback strategies."""
-        logger.info(f"Loading model {self.model_name} ...")
+        """Load model and tokenizer with comprehensive fallback strategies."""
+        logger.info(f"Starting model loading process for {self.model_name}")
+        self._log_system_info()

         try:
-            # Load tokenizer first
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-            logger.info("✅ Tokenizer loaded successfully")
+            # First, inspect the model configuration
+            logger.info("Step 1: Inspecting model configuration...")
+            config = AutoConfig.from_pretrained(self.model_name)
+            logger.info(f"Model architecture: {config.architectures}")
+            logger.info(f"Model type: {config.model_type}")
+            logger.info(f"Config class: {type(config).__name__}")
+
+            # Load tokenizer
+            logger.info("Step 2: Loading tokenizer...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name,
+                trust_remote_code=True,
+                use_fast=False  # Sometimes fast tokenizers cause issues
+            )

-            # Add pad token if it doesn't exist
+            # Add special tokens if missing
             if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
+                if self.tokenizer.eos_token is not None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+                else:
+                    self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

-            # Check model configuration
-            config = AutoConfig.from_pretrained(self.model_name)
-            logger.info(f"Model config type: {type(config).__name__}")
+            logger.info("✅ Tokenizer loaded successfully")
+            logger.info(f"Vocab size: {len(self.tokenizer)}")
+            logger.info(f"Special tokens: pad={self.tokenizer.pad_token}, eos={self.tokenizer.eos_token}")

-            # Try different model loading strategies
-            loading_strategies = [
-                ("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
-                    self.model_name, trust_remote_code=True
+            # Try loading with pipeline first (often more robust)
+            logger.info("Step 3: Attempting pipeline loading...")
+            pipeline_strategies = [
+                ("text2text-generation", lambda: pipeline(
+                    "text2text-generation",
+                    model=self.model_name,
+                    tokenizer=self.tokenizer,
+                    trust_remote_code=True,
+                    device=-1  # CPU
                 )),
-                ("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
-                    self.model_name, trust_remote_code=True
+                ("text-generation", lambda: pipeline(
+                    "text-generation",
+                    model=self.model_name,
+                    tokenizer=self.tokenizer,
+                    trust_remote_code=True,
+                    device=-1
                 )),
-                ("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
-                    self.model_name, trust_remote_code=True
+            ]
+
+            for pipe_type, pipe_func in pipeline_strategies:
+                try:
+                    logger.info(f"Trying {pipe_type} pipeline...")
+                    self.pipeline = pipe_func()
+                    logger.info(f"✅ Successfully loaded {pipe_type} pipeline")
+                    self.model_type = f"{pipe_type}_pipeline"
+                    return  # Success!
+                except Exception as e:
+                    logger.warning(f"❌ {pipe_type} pipeline failed: {str(e)[:200]}...")
+
+            # If pipeline loading fails, try direct model loading
+            logger.info("Step 4: Attempting direct model loading...")
+            loading_strategies = [
+                # Strategy 1: Based on config type, try the most appropriate loader
+                ("Config-based AutoModel", lambda: self._load_based_on_config(config)),
+
+                # Strategy 2: Force different model types with trust_remote_code
+                ("AutoModel + trust_remote_code", lambda: AutoModel.from_pretrained(
+                    self.model_name,
+                    trust_remote_code=True,
+                    torch_dtype=torch.float32,
+                    device_map="cpu"
                 )),
-                ("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
-                    self.model_name
+
+                ("AutoModelForCausalLM + trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
+                    self.model_name,
+                    trust_remote_code=True,
+                    torch_dtype=torch.float32,
+                    device_map="cpu"
                 )),
-                ("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
-                    self.model_name
+
+                ("AutoModelForSeq2SeqLM + trust_remote_code + ignore_mismatched", lambda: AutoModelForSeq2SeqLM.from_pretrained(
+                    self.model_name,
+                    trust_remote_code=True,
+                    torch_dtype=torch.float32,
+                    ignore_mismatched_sizes=True,
+                    device_map="cpu"
+                )),
+
+                # Strategy 3: Try without trust_remote_code but with other options
+                ("AutoModel + low_cpu_mem", lambda: AutoModel.from_pretrained(
+                    self.model_name,
+                    torch_dtype=torch.float32,
+                    low_cpu_mem_usage=True,
+                    device_map="cpu"
                 )),
             ]

@@ -66,90 +142,232 @@ class CodeDebuggerWrapper:
                     self.model_type = type(self.model).__name__
                     logger.info(f"✅ Successfully loaded model with {strategy_name}")
                     logger.info(f"Model type: {self.model_type}")
-                    break
+
+                    # Set to eval mode
+                    if hasattr(self.model, 'eval'):
+                        self.model.eval()
+
+                    return  # Success!
+
                 except Exception as e:
-                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
-                    continue
-
-            if self.model is None:
-                raise RuntimeError("❌ Failed to load model with any strategy")
+                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:200]}...")
+                    logger.debug(f"Full error: {traceback.format_exc()}")

-            # Set to evaluation mode
-            self.model.eval()
-            logger.info("✅ Model set to evaluation mode")
+            # If we get here, all strategies failed
+            raise RuntimeError("❌ All model loading strategies failed")

         except Exception as e:
             logger.error(f"❌ Critical error in model loading: {e}")
+            logger.error(f"Full traceback: {traceback.format_exc()}")
             raise

+    def _load_based_on_config(self, config):
+        """Try to load the model based on its configuration type."""
+        config_type = type(config).__name__
+
+        if "T5" in config_type or "Seq2Seq" in config_type:
+            return AutoModelForSeq2SeqLM.from_pretrained(
+                self.model_name,
+                trust_remote_code=True,
+                config=config
+            )
+        elif "GPT" in config_type or "Causal" in config_type:
+            return AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                trust_remote_code=True,
+                config=config
+            )
+        else:
+            return AutoModel.from_pretrained(
+                self.model_name,
+                trust_remote_code=True,
+                config=config
+            )
+
     def debug(self, code: str) -> str:
-        """Debug the provided code."""
+        """Debug the provided code using the loaded model."""
         if not code or not code.strip():
             return "❌ Please provide some code to debug."

         try:
-            # Prepare input with a clear prompt
-            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"
+            # Use pipeline if available (more robust)
+            if self.pipeline is not None:
+                return self._debug_with_pipeline(code)
+
+            # Use direct model if pipeline not available
+            if self.model is not None:
+                return self._debug_with_model(code)
+
+            # Fallback: provide manual debugging suggestions
+            return self._manual_debug_suggestions(code)
+
+        except Exception as e:
+            logger.error(f"Error during debugging: {e}")
+            return f"❌ Error during debugging: {str(e)}\n\n" + self._manual_debug_suggestions(code)
+
+    def _debug_with_pipeline(self, code: str) -> str:
+        """Debug using the pipeline."""
+        try:
+            prompt = f"Fix this Python code:\n\n{code}\n\nFixed code:"
+
+            if "text2text" in self.model_type:
+                result = self.pipeline(prompt, max_length=512, num_beams=3, early_stopping=True)
+                return result[0]['generated_text'] if result else self._manual_debug_suggestions(code)
+
+            elif "text-generation" in self.model_type:
+                result = self.pipeline(prompt, max_new_tokens=256, num_return_sequences=1, temperature=0.7)
+                generated = result[0]['generated_text'] if result else ""
+
+                # Clean up the response
+                if prompt in generated:
+                    generated = generated.replace(prompt, "").strip()
+
+                return generated if generated else self._manual_debug_suggestions(code)
+
+        except Exception as e:
+            logger.error(f"Pipeline debugging failed: {e}")
+            return self._manual_debug_suggestions(code)
+
+    def _debug_with_model(self, code: str) -> str:
+        """Debug using the model directly."""
+        try:
+            prompt = f"Debug and fix this Python code:\n\n{code}\n\nFixed code:"

-            # Tokenize input
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
                 max_length=512,
                 truncation=True,
                 padding=True
             )

-            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")
-
-            # Generate response based on model type
             with torch.no_grad():
                 if hasattr(self.model, 'generate'):
-                    logger.info("Using generate method")
-
-                    # Generation parameters
-                    generation_kwargs = {
+                    outputs = self.model.generate(
                         **inputs,
-                        'max_new_tokens': 256,
-                        'num_beams': 3,
-                        'early_stopping': True,
-                        'do_sample': True,
-                        'temperature': 0.7,
-                        'pad_token_id': self.tokenizer.pad_token_id,
-                    }
-
-                    # Add eos_token_id if available
-                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
-                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
-
-                    try:
-                        outputs = self.model.generate(**generation_kwargs)
-                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-                        # Clean up the response
-                        if prompt in response:
-                            response = response.replace(prompt, "").strip()
-
-                        # If response is empty or too short, provide a fallback
-                        if not response or len(response.strip()) < 10:
-                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."
-
-                    except Exception as gen_error:
-                        logger.error(f"Generation error: {gen_error}")
-                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
+                        max_new_tokens=256,
+                        num_beams=3,
+                        early_stopping=True,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=getattr(self.tokenizer, 'eos_token_id', None),
+                        do_sample=True,
+                        temperature=0.7
+                    )
+
+                    response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                    # Clean response
+                    if prompt in response:
+                        response = response.replace(prompt, "").strip()
+
+                    return response if response else self._manual_debug_suggestions(code)

                 else:
-                    # For encoder-only models
-                    logger.info("Model doesn't support generation. Using encoder-only approach.")
-                    outputs = self.model(**inputs)
-
-                    # This is a simplified approach - you might need to add a classification head
-                    # or use a different strategy based on your specific model
-                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."
+                    return f"⚠️ Model type '{self.model_type}' doesn't support generation.\n\n" + self._manual_debug_suggestions(code)

-            return response
-
-        except Exception as e:
-            logger.error(f"Error in debug method: {e}")
-            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"
+        except Exception as e:
+            logger.error(f"Direct model debugging failed: {e}")
+            return self._manual_debug_suggestions(code)
+
+    def _manual_debug_suggestions(self, code: str) -> str:
+        """Provide manual debugging suggestions when the AI model fails."""
+        suggestions = []
+
+        # Check for common Python syntax errors
+        lines = code.split('\n')
+
+        for i, line in enumerate(lines, 1):
+            line_stripped = line.strip()
+            if not line_stripped or line_stripped.startswith('#'):
+                continue
+
+            # Check for missing colons
+            if any(keyword in line_stripped for keyword in ['if ', 'for ', 'while ', 'def ', 'class ', 'try:', 'except', 'else', 'elif']):
+                if not line_stripped.endswith(':'):
+                    suggestions.append(f"Line {i}: Missing colon (:) at end of statement")
+
+            # Check for obvious indentation issues
+            if i > 1 and line_stripped and not line.startswith(' ') and not line.startswith('\t'):
+                prev_line = lines[i-2].strip() if i > 1 else ""
+                if prev_line.endswith(':'):
+                    suggestions.append(f"Line {i}: Possible indentation error - code after ':' should be indented")
+
+        # Check for common runtime errors
+        if 'len(' in code and '[]' in code:
+            suggestions.append("⚠️ Potential division by zero: len() of an empty list is 0 - check for empty lists before dividing by len()")
+
+        if '/0' in code or '/ 0' in code:
+            suggestions.append("⚠️ Division by zero detected")
+
+        # Create response
+        result = f"🔧 **Manual Debug Analysis for:**\n```python\n{code}\n```\n\n"
+
+        if suggestions:
+            result += "**Issues Found:**\n"
+            for suggestion in suggestions:
+                result += f"• {suggestion}\n"
+        else:
+            result += "**No obvious syntax errors detected.**\n"
+
+        result += "\n**General Tips:**\n"
+        result += "• Check for missing colons (:) after if/for/def statements\n"
+        result += "• Verify proper indentation (4 spaces per level)\n"
+        result += "• Ensure all parentheses, brackets, and quotes are balanced\n"
+        result += "• Check for typos in variable and function names\n"
+        result += "• Make sure all required imports are included\n"
+
+        return result
+
+
+# Alternative lightweight debugger if the main one fails completely
+class FallbackDebugger:
+    def __init__(self):
+        self.model = None
+        self.tokenizer = None
+        logger.info("Using fallback debugger - AI model unavailable")
+
+    def debug(self, code: str) -> str:
+        """Simple rule-based debugging."""
+        if not code or not code.strip():
+            return "❌ Please provide some code to debug."
+
+        issues = []
+        lines = code.split('\n')
+
+        # Basic syntax checking
+        for i, line in enumerate(lines, 1):
+            stripped = line.strip()
+            if not stripped or stripped.startswith('#'):
+                continue
+
+            # Missing colons
+            control_words = ['if ', 'elif ', 'else', 'for ', 'while ', 'def ', 'class ', 'try', 'except', 'finally']
+            if any(word in stripped for word in control_words):
+                if not stripped.endswith(':'):
+                    issues.append(f"Line {i}: Missing colon (:)")
+
+            # Indentation after colon
+            if i < len(lines) and stripped.endswith(':'):
+                next_line = lines[i] if i < len(lines) else ""
+                if next_line.strip() and not next_line.startswith((' ', '\t')):
+                    issues.append(f"Line {i+1}: Should be indented after ':'")
+
+        # Generate response
+        result = "🔧 **Code Analysis** (AI Model Unavailable)\n\n"
+        result += f"```python\n{code}\n```\n\n"
+
+        if issues:
+            result += "**Potential Issues:**\n"
+            for issue in issues:
+                result += f"• {issue}\n"
+        else:
+            result += "**No obvious syntax errors found.**\n"
+
+        result += "\n**Common Debugging Steps:**\n"
+        result += "1. Run the code to see specific error messages\n"
+        result += "2. Check syntax: colons, indentation, parentheses\n"
+        result += "3. Verify variable names and imports\n"
+        result += "4. Use print() statements to debug logic\n"
+
+        return result
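
For reference, a minimal usage sketch of the two classes this commit touches. The constructor signature is an assumption: the hunks above only show that self.model_name is set before _ensure_model() runs, so the argument name and the placeholder model id below are hypothetical.

    # Hypothetical driver script - not part of this commit.
    # Assumes CodeDebuggerWrapper.__init__ accepts a model name (not shown in the diff).
    from model_wrapper import CodeDebuggerWrapper, FallbackDebugger

    try:
        debugger = CodeDebuggerWrapper("some-org/some-model")  # placeholder id
    except Exception:
        # _ensure_model() raises RuntimeError once every pipeline and
        # direct-loading strategy has failed, so fall back to rules.
        debugger = FallbackDebugger()

    # Missing colon after the def line; the rule-based fallback flags this too.
    print(debugger.debug("def add(a, b)\n    return a + b"))

Both classes expose the same debug(code: str) -> str interface, which is what makes FallbackDebugger a drop-in replacement when model loading fails.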