import os
import threading
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoConfig
)
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class CodeDebuggerWrapper:
    def __init__(self, model_name="Girinath11/aiml_code_debug_model"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.model_type = None
        self._ensure_model()

    def _ensure_model(self):
        """Load model and tokenizer with fallback strategies."""
        logger.info(f"Loading model {self.model_name} ...")
        try:
            # Load tokenizer first
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            logger.info("✅ Tokenizer loaded successfully")

            # Add pad token if it doesn't exist
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Check model configuration
            config = AutoConfig.from_pretrained(self.model_name)
            logger.info(f"Model config type: {type(config).__name__}")

            # Try different model loading strategies
            loading_strategies = [
                ("AutoModel with trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForCausalLM with trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModelForSeq2SeqLM with trust_remote_code", lambda: AutoModelForSeq2SeqLM.from_pretrained(
                    self.model_name, trust_remote_code=True
                )),
                ("AutoModel without trust_remote_code", lambda: AutoModel.from_pretrained(
                    self.model_name
                )),
                ("AutoModelForCausalLM without trust_remote_code", lambda: AutoModelForCausalLM.from_pretrained(
                    self.model_name
                )),
            ]

            for strategy_name, strategy_func in loading_strategies:
                try:
                    logger.info(f"Trying: {strategy_name}")
                    self.model = strategy_func()
                    self.model_type = type(self.model).__name__
                    logger.info(f"✅ Successfully loaded model with {strategy_name}")
                    logger.info(f"Model type: {self.model_type}")
                    break
                except Exception as e:
                    logger.warning(f"❌ {strategy_name} failed: {str(e)[:100]}...")
                    continue

            if self.model is None:
                raise RuntimeError("❌ Failed to load model with any strategy")

            # Set to evaluation mode
            self.model.eval()
            logger.info("✅ Model set to evaluation mode")

        except Exception as e:
            logger.error(f"❌ Critical error in model loading: {e}")
            raise
    def debug(self, code: str) -> str:
        """Debug the provided code."""
        if not code or not code.strip():
            return "❌ Please provide some code to debug."

        try:
            # Prepare input with a clear prompt
            prompt = f"### Task: Debug and fix the following Python code\n\n### Input Code:\n{code}\n\n### Fixed Code:"

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            logger.info(f"Input tokenized. Input IDs shape: {inputs['input_ids'].shape}")

            # Generate response based on model type
            with torch.no_grad():
                if hasattr(self.model, 'generate'):
                    logger.info("Using generate method")
                    # Generation parameters
                    generation_kwargs = {
                        **inputs,
                        'max_new_tokens': 256,
                        'num_beams': 3,
                        'early_stopping': True,
                        'do_sample': True,
                        'temperature': 0.7,
                        'pad_token_id': self.tokenizer.pad_token_id,
                    }

                    # Add eos_token_id if available
                    if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
                        generation_kwargs['eos_token_id'] = self.tokenizer.eos_token_id

                    try:
                        outputs = self.model.generate(**generation_kwargs)
                        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

                        # Clean up the response
                        if prompt in response:
                            response = response.replace(prompt, "").strip()

                        # If response is empty or too short, provide a fallback
                        if not response or len(response.strip()) < 10:
                            response = f"The model processed your code but didn't generate a clear fix. Original code:\n\n{code}\n\n💡 Tip: Check for common issues like indentation, syntax errors, or missing imports."
                    except Exception as gen_error:
                        logger.error(f"Generation error: {gen_error}")
                        response = f"❌ Error during code generation: {str(gen_error)}\n\nOriginal code:\n{code}"
                else:
                    # For encoder-only models
                    logger.info("Model doesn't support generation. Using encoder-only approach.")
                    outputs = self.model(**inputs)
                    # This is a simplified approach - you might need to add a classification head
                    # or use a different strategy based on your specific model
                    response = f"⚠️ This model type ({self.model_type}) doesn't support text generation directly.\n\nOriginal code:\n{code}\n\n💡 Consider using a generative model (T5, GPT, BART) for code debugging tasks."

            return response

        except Exception as e:
            logger.error(f"Error in debug method: {e}")
            return f"❌ Error during debugging: {str(e)}\n\nOriginal code:\n{code}"
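# Example usage (a minimal sketch, not part of the original file): the
# __main__ guard and the sample buggy snippet below are assumptions added
# only to illustrate how CodeDebuggerWrapper is intended to be called.
if __name__ == "__main__":
    debugger = CodeDebuggerWrapper()
    buggy_code = "def add(a, b):\n    return a - b  # wrong operator"
    print(debugger.debug(buggy_code))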