import os
import threading
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig,
)
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CodeDebuggerWrapper:
    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.model_type = None
        self._ensure_model()
    def _ensure_model(self):
        """Load model and tokenizer with fallback strategies."""
        logger.info(f"Loading model {self.model_name} ...")
        try:
            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            logger.info("✅ Tokenizer loaded successfully")

            # Add pad token if it doesn't exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Check model configuration
            config = AutoConfig.from_pretrained(self.model_name)
            logger.info(f"Model config type: {type(config).__name__}")

            # Try different model loading strategies
            loading_strategies = [
                ("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name
                )),
                ("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name
                )),
            ]

            for strategy_name, strategy_func in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy_name}")
                    self.model = strategy_func()
                    self.model_type = type(self.model).__name__
                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
                    logger.info(f"Model type: {self.model_type}")
                    break
                except Exception as e:
                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
                    continue

            if self.model is None:
                raise RuntimeError("❌ Failed to load model with any strategy")

            # Set to evaluation mode
            self.model.eval()
            logger.info("✅ Model set to evaluation mode")

        except Exception as e:
            logger.error(f"❌ Critical error in model loading: {e}")
            raise
    def debug(self, code: str) -> str:
        """Debug the provided code."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."

        try:
            # Prepare input with a clear prompt
            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")

            # Generate response based on model type
            with torch.no_grad():
                if hasattr(self.model, 'generate'):
                    logger.info("Using generate method")
                    # Generation parameters
                    generation_kwargs = {
                        **inputs,
                        'max_new_tokens': 256,
                        'num_beams': 3,
                        'early_stopping': True,
                        'do_sample': True,
                        'temperature': 0.7,
                        'pad_token_id': self.tokenizer.pad_token_id,
                    }

                    # Add eos_token_id if available
                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id

                    try:
                        outputs = self.model.generate(**generation_kwargs)
                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Clean up the response
                        if prompt in response:
                            response = response.replace(prompt, "").strip()

                        # If response is empty or too short, provide a fallback
                        if not response or len(response.strip()) < 10:
                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."

                    except Exception as gen_error:
                        logger.error(f"Generation error: {gen_error}")
                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"

                else:
                    # For encoder-only models
                    logger.info("Model doesn't support generation. Using encoder-only approach.")
                    outputs = self.model(**inputs)
                    # This is a simplified approach - you might need to add a classification head
                    # or use a different strategy based on your specific model
                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."

            return response

        except Exception as e:
            logger.error(f"Error in debug method: {e}")
            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"