# jina-code-debugger/model_wrapper.py
from transformers import (
AutoTokenizer,
AutoModel,
AutoModelForSeq2SeqLM,
AutoModelForCausalLM,
AutoConfig
)
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CodeDebuggerWrapper:
    """Load the code-debugging model once and expose a simple debug() API."""

    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
self.model_name = model_name
self.model = None
self.tokenizer = None
self.model_type = None
self._ensure_model()

    def _ensure_model(self):
"""Load model and tokenizer with fallback strategies."""
logger.info(f"Loading model {self.model_name} ...")
try:
# Load tokenizer first
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
logger.info("βœ… Tokenizer loaded successfully")
# Add pad token if it doesn't exist
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Check model configuration
config = AutoConfig.from_pretrained(self.model_name)
logger.info(f"Model config type: {type(config).__name__}")
# Try different model loading strategies
loading_strategies = [
("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
self.model_name, trust_remote_code=True
)),
("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
self.model_name, trust_remote_code=True
)),
("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
self.model_name, trust_remote_code=True
)),
("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
self.model_name
)),
("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
self.model_name
)),
]
for strategy_name, strategy_func in loading_strategies:
try:
logger.info(f"Trying: {strategy_name}")
self.model = strategy_func()
self.model_type = type(self.model).__name__
logger.info(f"βœ… Successfully loaded model with {strategy_name}")
logger.info(f"Model type: {self.model_type}")
break
except Exception as e:
logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
continue
if self.model is None:
raise RuntimeError("❌ Failed to load model with any strategy")
# Set to evaluation mode
self.model.eval()
logger.info("βœ… Model set to evaluation mode")
except Exception as e:
logger.error(f"❌ Critical error in model loading: {e}")
raise

    def debug(self, code: str) -> str:
"""Debug the provided code."""
if not code or not code.strip():
return "❌ Please provide some code to debug."
try:
# Prepare input with a clear prompt
prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"
# Tokenize input
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True
)
logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")
# Generate response based on model type
with torch.no_grad():
if hasattr(self.model, 'generate'):
logger.info("Using generate method")
# Generation parameters
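                    # Note: do_sample=True together with num_beams=3 selects
                    # beam-search multinomial sampling; early_stopping ends the
                    # search once num_beams finished candidates exist.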
generation_kwargs = {
**inputs,
'max_new_tokens': 256,
'num_beams': 3,
'early_stopping': True,
'do_sample': True,
'temperature': 0.7,
'pad_token_id': self.tokenizer.pad_token_id,
}
# Add eos_token_id if available
                    # eos_token_id can be None for some tokenizers; only pass it when set
                    if self.tokenizer.eos_token_id is not None:
                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
try:
outputs = self.model.generate(**generation_kwargs)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                        # Clean up: causal LMs echo the prompt in their output,
                        # seq2seq models do not, so strip it only if present
                        if prompt in response:
                            response = response.replace(prompt, "").strip()
# If response is empty or too short, provide a fallback
if not response or len(response.strip()) < 10:
response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\nπŸ’‘ Tip: Check for common issues like indentation, syntax errors, or missing imports."
except Exception as gen_error:
logger.error(f"Generation error: {gen_error}")
response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
else:
# For encoder-only models
logger.info("Model doesn't support generation. Using encoder-only approach.")
                    outputs = self.model(**inputs)  # forward pass; hidden states only, unused below
# This is a simplified approach - you might need to add a classification head
# or use a different strategy based on your specific model
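                    # One possible strategy (a sketch, not part of this pipeline):
                    # pool the encoder's last hidden state into a single embedding,
                    # e.g. to retrieve the closest known-buggy snippet from a corpus:
                    #   pooled = outputs.last_hidden_state.mean(dim=1)  # (1, hidden_size)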
response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\nπŸ’‘ Consider using a generative model (T5, GPT, BART) for code debugging tasks."
return response
except Exception as e:
logger.error(f"Error in debug method: {e}")
return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"