Girinath11 committed
Commit bde5911 · verified · 1 Parent(s): 87ce049

Update model_wrapper.py

Files changed (1)
  1. model_wrapper.py +146 -29
model_wrapper.py CHANGED
@@ -1,38 +1,155 @@
 import os
 import threading
 
+from transformers import (
+    AutoTokenizer,
+    AutoModel,
+    AutoModelForSeq2SeqLM,
+    AutoModelForCausalLM,
+    AutoConfig
+)
+import torch
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 class CodeDebuggerWrapper:
-    """
-    Simple wrapper that loads the same HF model and exposes debug(code: str) -> str
-    This is used by app.py (Gradio).
-    """
-    def __init__(self, model_name: str = "Girinath11/aiml_code_debug_model"):
+    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
         self.model_name = model_name
-        self._lock = threading.Lock()
-        self.tokenizer = None
         self.model = None
-        self.max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", "256"))
+        self.tokenizer = None
+        self.model_type = None
         self._ensure_model()
-
+
     def _ensure_model(self):
-        # allow skipping in environments where you don't want to download weights
-        skip = os.environ.get("SKIP_MODEL_LOAD", "0") == "1"
-        if skip:
-            print("SKIP_MODEL_LOAD=1 -> not loading model.")
-            return
-
-        if self.model is None or self.tokenizer is None:
-            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
-            with self._lock:
-                if self.model is None or self.tokenizer is None:
-                    print(f"Loading model {self.model_name} ...")
-                    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-                    self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
-                    print("Model loaded.")
-
+        """Load model and tokenizer with fallback strategies."""
+        logger.info(f"Loading model {self.model_name} ...")
+
+        try:
+            # Load tokenizer first
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            logger.info("✅ Tokenizer loaded successfully")
+
+            # Add pad token if it doesn't exist
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Check model configuration
+            config = AutoConfig.from_pretrained(self.model_name)
+            logger.info(f"Model config type: {type(config).__name__}")
+
+            # Try different model loading strategies
+            loading_strategies = [
+                ("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
+                    self.model_name, trust_remote_code=True
+                )),
+                ("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
+                    self.model_name, trust_remote_code=True
+                )),
+                ("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
+                    self.model_name, trust_remote_code=True
+                )),
+                ("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
+                    self.model_name
+                )),
+                ("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
+                    self.model_name
+                )),
+            ]
+
+            for strategy_name, strategy_func in loading_strategies:
+                try:
+                    logger.info(f"Trying: {strategy_name}")
+                    self.model = strategy_func()
+                    self.model_type = type(self.model).__name__
+                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
+                    logger.info(f"Model type: {self.model_type}")
+                    break
+                except Exception as e:
+                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
+                    continue
+
+            if self.model is None:
+                raise RuntimeError("❌ Failed to load model with any strategy")
+
+            # Set to evaluation mode
+            self.model.eval()
+            logger.info("✅ Model set to evaluation mode")
+
+        except Exception as e:
+            logger.error(f"❌ Critical error in model loading: {e}")
+            raise
+
     def debug(self, code: str) -> str:
-        if self.model is None or self.tokenizer is None:
-            return "Model not loaded. Set SKIP_MODEL_LOAD=0 and ensure HF token is available if model is private."
-        inputs = self.tokenizer(code, return_tensors="pt", padding=True, truncation=True)
-        outputs = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
-        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+        """Debug the provided code."""
+        if not code or not code.strip():
+            return "❌ Please provide some code to debug."
+
+        try:
+            # Prepare input with a clear prompt
+            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"
+
+            # Tokenize input
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                max_length=512,
+                truncation=True,
+                padding=True
+            )
+
+            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")
+
+            # Generate response based on model type
+            with torch.no_grad():
+                if hasattr(self.model, 'generate'):
+                    logger.info("Using generate method")
+
+                    # Generation parameters
+                    generation_kwargs = {
+                        **inputs,
+                        'max_new_tokens': 256,
+                        'num_beams': 3,
+                        'early_stopping': True,
+                        'do_sample': True,
+                        'temperature': 0.7,
+                        'pad_token_id': self.tokenizer.pad_token_id,
+                    }
+
+                    # Add eos_token_id if available
+                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
+                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
+
+                    try:
+                        outputs = self.model.generate(**generation_kwargs)
+                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                        # Clean up the response
+                        if prompt in response:
+                            response = response.replace(prompt, "").strip()
+
+                        # If response is empty or too short, provide a fallback
+                        if not response or len(response.strip()) < 10:
+                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."
+
+                    except Exception as gen_error:
+                        logger.error(f"Generation error: {gen_error}")
+                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
+
+                else:
+                    # For encoder-only models
+                    logger.info("Model doesn't support generation. Using encoder-only approach.")
+                    outputs = self.model(**inputs)
+
+                    # This is a simplified approach - you might need to add a classification head
+                    # or use a different strategy based on your specific model
+                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."
+
+            return response
+
+        except Exception as e:
+            logger.error(f"Error in debug method: {e}")
+            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"
+
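
The removed docstring noted that this wrapper is consumed by app.py (Gradio). For context, below is a minimal sketch of how such an app.py might call the new debug() method; the interface labels, textbox sizes, and launch call here are assumptions for illustration and are not part of this commit.

# app.py (hypothetical usage sketch; only CodeDebuggerWrapper.debug() comes from model_wrapper.py)
import gradio as gr

from model_wrapper import CodeDebuggerWrapper

# Instantiate once at startup so the model and tokenizer load a single time.
wrapper = CodeDebuggerWrapper()

def debug_code(code: str) -> str:
    # Delegates to the wrapper; returns either a suggested fix or one of the
    # error/fallback messages produced by debug().
    return wrapper.debug(code)

demo = gr.Interface(
    fn=debug_code,
    inputs=gr.Textbox(lines=15, label="Buggy Python code"),
    outputs=gr.Textbox(lines=15, label="Suggested fix"),
    title="AIML Code Debugger",
)

if __name__ == "__main__":
    demo.launch()

Because _ensure_model() now raises on failure instead of silently skipping (the SKIP_MODEL_LOAD path was removed), constructing the wrapper at import time will surface loading errors as soon as the Space starts rather than at the first request.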