Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import logging | |
import json | |
import re | |
from typing import List, Dict | |
from datetime import datetime | |
# Application-wide logging: configure the root logger once with timestamped,
# INFO-level records, then take a module-scoped logger for this file.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# --- Model configuration ---------------------------------------------------
# Fine-tuned checkpoint that is tried first.
MODEL_NAME = "wizcodes12/snaxfix-model"
# General-purpose checkpoint loaded instead when MODEL_NAME cannot be loaded.
FALLBACK_MODEL = "google/flan-t5-small"

# Languages offered in the UI dropdown and accepted by the fixer.
SUPPORTED_LANGUAGES = [
    "python",
    "javascript",
    "java",
    "c",
    "cpp",
    "csharp",
    "rust",
    "php",
    "html",
    "css",
    "sql",
]

# Token budget used both for truncating the prompt and for generation.
MAX_LENGTH = 512
# Example code snippets with errors for testing.
# Maps a language name (same keys as SUPPORTED_LANGUAGES) to:
#   "broken":      a snippet containing a deliberate syntax error; this is what
#                  the "Load Example" button puts into the editor.
#   "description": a short human-readable note naming the injected error
#                  (not referenced by the code visible in this file).
EXAMPLE_SNIPPETS = {
    "python": {
        "broken": 'def add(a b):\n return a + b',
        "description": "Missing comma in function parameters"
    },
    "javascript": {
        "broken": 'function greet() {\n console.log("Hello"\n}',
        # The function's closing brace is present; only the ")" is missing.
        "description": "Missing closing parenthesis"
    },
    "java": {
        "broken": 'public class Hello {\n public static void main(String[] args) {\n System.out.println("Hello World")\n }\n}',
        "description": "Missing semicolon"
    },
    "c": {
        "broken": '#include <stdio.h>\n\nint main() {\n printf("Hello World")\n return 0;\n}',
        "description": "Missing semicolon"
    },
    "cpp": {
        "broken": '#include <iostream>\n\nint main() {\n std::cout << "Hello World" << std::endl\n return 0;\n}',
        "description": "Missing semicolon"
    },
    "csharp": {
        "broken": 'class Program {\n static void Main(string[] args) {\n Console.WriteLine("Hello World")\n }\n}',
        "description": "Missing semicolon"
    },
    "rust": {
        "broken": 'fn main() {\n println!("Hello World")\n}',
        "description": "Missing semicolon"
    },
    "php": {
        "broken": '<?php\n echo "Hello World"\n?>',
        "description": "Missing semicolon"
    },
    "html": {
        "broken": '<div>\n <p>Hello World</div>\n</p>',
        "description": "Incorrect tag nesting"
    },
    "css": {
        "broken": 'body {\n background-color: #ffffff\n}',
        "description": "Missing semicolon"
    },
    "sql": {
        "broken": 'SELECT name age FROM users WHERE age > 18',
        "description": "Missing comma in SELECT clause"
    },
}
class SyntaxFixerApp:
    """Fix syntax errors in code snippets with a seq2seq model.

    The constructor loads ``MODEL_NAME`` and, if that fails, ``FALLBACK_MODEL``.
    Successful fixes are kept in an in-memory history that the UI can display.
    """

    # Maximum number of history entries kept. One instance serves every UI
    # event for the life of the process (see create_gradio_interface), so an
    # uncapped list would grow without bound; get_history only ever shows the
    # last 5 entries anyway.
    MAX_HISTORY_ENTRIES = 100

    def __init__(self):
        """Pick a device, load the model and tokenizer, and start an empty history.

        Raises:
            RuntimeError: If neither the primary nor the fallback model can be
                loaded. It subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        logger.info("Initializing SyntaxFixerApp...")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")

        # Load model and tokenizer, falling back to a generic checkpoint.
        self.model_name_used = None
        try:
            logger.info(f"Attempting to load primary model: {MODEL_NAME}")
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
            self.model_name_used = MODEL_NAME
            logger.info("Primary model and tokenizer loaded successfully")
        except Exception as e:
            logger.warning(f"Failed to load primary model: {e}")
            logger.info(f"Attempting to load fallback model: {FALLBACK_MODEL}")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(FALLBACK_MODEL)
                self.model_name_used = FALLBACK_MODEL
                logger.info("Fallback model and tokenizer loaded successfully")
            except Exception as fallback_error:
                logger.error(f"Failed to load fallback model: {fallback_error}")
                # Chain the cause so the fallback traceback is preserved.
                raise RuntimeError(
                    "Failed to load both primary and fallback models. "
                    f"Primary: {e}, Fallback: {fallback_error}"
                ) from fallback_error

        self.model.to(self.device)
        self.model.eval()  # inference only
        logger.info(f"Using model: {self.model_name_used}")

        # Initialize history
        self.history = []

    def _build_prompt(self, broken_code: str, language: str) -> str:
        """Return the model prompt for *broken_code*.

        The primary (fine-tuned) model gets a ``<LANGUAGE>`` tag prefix; the
        fallback model gets the plain instruction only.
        """
        if self.model_name_used == FALLBACK_MODEL:
            return f"Fix the syntax errors in this {language} code: {broken_code}"
        return f"<{language.upper()}> Fix the syntax errors in this {language} code: {broken_code}"

    def _record_fix(self, language: str, broken_code: str, fixed_code: str) -> None:
        """Append one fix to the history, dropping the oldest beyond the cap."""
        self.history.append({
            "timestamp": datetime.now().isoformat(),
            "language": language,
            "broken_code": broken_code,
            "fixed_code": fixed_code,
        })
        overflow = len(self.history) - self.MAX_HISTORY_ENTRIES
        if overflow > 0:
            del self.history[:overflow]

    def fix_syntax(self, broken_code: str, language: str) -> str:
        """Fix syntax errors in the provided code.

        Args:
            broken_code: Source text to repair. ``None`` or whitespace-only
                input is rejected.
            language: One of ``SUPPORTED_LANGUAGES``.

        Returns:
            The model's output text, or a string starting with ``"Error:"``.
            Failures are reported as strings, never raised, so the UI can
            display them directly.
        """
        # gr.Code may hand over None for an empty editor; treat it like "".
        if not broken_code or not broken_code.strip():
            return "Error: Please enter code to fix."
        if language not in SUPPORTED_LANGUAGES:
            return f"Error: Language '{language}' is not supported. Choose from: {', '.join(SUPPORTED_LANGUAGES)}"
        try:
            input_text = self._build_prompt(broken_code, language)
            inputs = self.tokenizer(
                input_text,
                max_length=MAX_LENGTH,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(self.device)
            # Deterministic beam search. No `temperature` is passed: it only
            # affects sampling (do_sample=True) and is otherwise ignored, with
            # transformers warning about the unused flag.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=MAX_LENGTH,
                    num_beams=4,
                    early_stopping=True,
                    do_sample=False,
                    pad_token_id=self.tokenizer.pad_token_id,
                    use_cache=True,
                )
            fixed_code = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            self._record_fix(language, broken_code, fixed_code)
            return fixed_code
        except Exception as e:
            # UI boundary: log with traceback, report the message to the user.
            logger.exception(f"Error fixing code: {e}")
            return f"Error: Failed to fix code - {str(e)}"

    def load_example(self, language: str) -> str:
        """Return the example broken snippet for *language*, or a placeholder."""
        return EXAMPLE_SNIPPETS.get(language, {}).get("broken", "No example available for this language.")

    def get_history(self) -> str:
        """Return the last 5 history entries formatted as plain text."""
        if not self.history:
            return "No history available."
        parts = ["=== Fix History ===\n"]
        for entry in self.history[-5:]:  # keep the displayed text short
            parts.append(f"Timestamp: {entry['timestamp']}\n")
            parts.append(f"Language: {entry['language']}\n")
            parts.append(f"Broken Code:\n{entry['broken_code']}\n")
            parts.append(f"Fixed Code:\n{entry['fixed_code']}\n")
            parts.append("-" * 50 + "\n")
        return "".join(parts)

    def clear_history(self) -> str:
        """Clear the history of fixes and return a confirmation message."""
        self.history = []
        return "History cleared."
# SUPPORTED_LANGUAGES entries that gr.Code can highlight under the same name.
# The others (java, csharp, rust, php) are not among gr.Code's accepted
# `language` values, so the editors show them as plain text.
_CODE_EDITOR_LANGUAGES = frozenset({"python", "javascript", "c", "cpp", "html", "css", "sql"})


def code_editor_language(language):
    """Map a dropdown language to a gr.Code ``language`` value (None = plain text)."""
    return language if language in _CODE_EDITOR_LANGUAGES else None


def create_gradio_interface():
    """Create and return the Gradio interface.

    Builds a single SyntaxFixerApp (which loads the model) shared by every
    event handler, and wires the input/output editors, example loader, and
    history panel to it.
    """
    app = SyntaxFixerApp()
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# SnaxFix: Advanced Syntax Error Fixer")
        gr.Markdown("Fix syntax errors in code across multiple programming languages using AI models.")
        # Show which model is being used
        gr.Markdown(f"**Currently using model:** `{app.model_name_used}`")
        with gr.Row():
            with gr.Column(scale=2):
                language_dropdown = gr.Dropdown(
                    choices=SUPPORTED_LANGUAGES,
                    label="Select Programming Language",
                    value="python"
                )
                code_input = gr.Code(
                    label="Enter Code with Syntax Errors",
                    lines=10,
                    language="python"
                )
                with gr.Row():
                    fix_button = gr.Button("Fix Syntax", variant="primary")
                    example_button = gr.Button("Load Example", variant="secondary")
                    clear_button = gr.Button("Clear Input", variant="secondary")
            with gr.Column(scale=2):
                code_output = gr.Code(
                    label="Fixed Code",
                    lines=10,
                    language="python"
                )
        with gr.Accordion("History of Fixes", open=False):
            history_output = gr.Textbox(label="Fix History", lines=10)
            with gr.Row():
                refresh_history_button = gr.Button("Refresh History")
                clear_history_button = gr.Button("Clear History")
        with gr.Accordion("About & License", open=False):
            gr.Markdown("""
**About SnaxFix**
SnaxFix is an AI-powered tool for fixing syntax errors in multiple programming languages, built with `google/flan-t5-base` and fine-tuned by wizcodes12.
**MIT License**
This project is licensed under the MIT License. See the [LICENSE](https://github.com/wizcodes12/snaxfix-model/blob/main/LICENSE) file for details.
""")

        # Event handlers
        def update_code_language(language):
            """Re-highlight both editors for the newly selected language."""
            editor_language = code_editor_language(language)
            return gr.update(language=editor_language), gr.update(language=editor_language)

        def fix_and_update_history(code, language):
            """Fix code and return both fixed code and updated history."""
            fixed = app.fix_syntax(code, language)
            history = app.get_history()
            return fixed, history

        # Main fix button - fixes code and updates history
        fix_button.click(
            fn=fix_and_update_history,
            inputs=[code_input, language_dropdown],
            outputs=[code_output, history_output]
        )
        # Load example button
        example_button.click(
            fn=app.load_example,
            inputs=language_dropdown,
            outputs=code_input
        )
        # Clear input button
        clear_button.click(
            fn=lambda: "",
            inputs=None,
            outputs=code_input
        )
        # Language dropdown change - keeps input AND output highlighting in sync
        language_dropdown.change(
            fn=update_code_language,
            inputs=language_dropdown,
            outputs=[code_input, code_output]
        )
        # Refresh history button
        refresh_history_button.click(
            fn=app.get_history,
            inputs=None,
            outputs=history_output
        )
        # Clear history button
        clear_history_button.click(
            fn=app.clear_history,
            inputs=None,
            outputs=history_output
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI (constructing it loads the model),
    # then serve it.
    logger.info("Starting Gradio application...")
    demo = create_gradio_interface()
    try:
        demo.launch()
    except Exception as launch_error:
        # Record the failure in the app log, then propagate it unchanged.
        logger.error(f"Failed to launch Gradio app: {launch_error}")
        raise