import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import re

# Configure the root logger: INFO and above are written to app.log,
# one record per line as "<timestamp>:<LEVEL>:<message>".
_LOG_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, filename="app.log")

# Model and tokenizer loading function with caching
_MODEL_CACHE = None  # (model, tokenizer) once a load has succeeded; None until then


def load_model():
    """
    Loads the pre-trained language model and tokenizer, reusing them across calls.

    The first successful load is stored in the module-level ``_MODEL_CACHE`` and
    returned directly on every later call, so the weights are read only once
    per process. A failed load is NOT cached, so the next call retries.

    Returns:
        tuple: ``(model, tokenizer)`` on success, ``(None, None)`` if loading
        raised any exception (the error and traceback are logged).
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = "Canstralian/pentest_ai"  # Replace with the actual path if different
        # NOTE(security): trust_remote_code=True allows Python code shipped with
        # the model repository to run locally; only use it with a trusted repo.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Half precision only on GPU; CPU stays in float32.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map={"": device},  # Place the whole model on the chosen device
            load_in_8bit=False,  # Disabled for stability
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # logging.exception keeps the traceback, which the plain message loses.
        logging.exception(f"Error loading model: {e}")
        return None, None

    _MODEL_CACHE = (model, tokenizer)
    logging.info("Model and tokenizer loaded successfully.")
    return _MODEL_CACHE

def sanitize_input(text):
    """
    Strips every character outside a small whitelist from the user's text.

    Kept characters: ASCII letters, digits, whitespace, and the punctuation
    ``. , ! ?``. Everything else (including ``-``, ``/``, quotes, brackets and
    non-ASCII letters) is deleted, then leading/trailing whitespace is trimmed.

    Args:
        text (str): User input text.
    Returns:
        str: The filtered, trimmed text (may be empty).
    Raises:
        ValueError: If ``text`` is not a string.
    """
    if isinstance(text, str):
        # Negated character class: matches exactly the characters to drop.
        disallowed = re.compile(r"[^a-zA-Z0-9\s\.,!?]")
        return disallowed.sub("", text).strip()
    raise ValueError("Input must be a string.")

def generate_text(model, tokenizer, instruction):
    """
    Generates text based on the provided instruction using the loaded model.

    Args:
        model: The language model; must expose ``generate`` and ``device``.
        tokenizer: Tokenizer for encoding/decoding.
        instruction (str): Instruction text for the model.
    Returns:
        str: Generated text response from the model, or the fixed message
        "Error in text generation." if encoding/generation/decoding fails.
    Raises:
        ValueError: If ``instruction`` is not a string, or is empty once
            sanitized. Raised before the try block so the caller's
            ``except ValueError`` handler can report it as invalid input.
    """
    # Validate and sanitize outside the try block: input errors must reach the
    # caller instead of being flattened into the generic generation error.
    instruction = sanitize_input(instruction)
    if not instruction:
        raise ValueError("Instruction is empty after sanitization.")

    try:
        # Put the prompt on the device the model actually lives on, rather than
        # re-deriving it from CUDA availability.
        tokens = tokenizer.encode(instruction, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            tokens,
            max_length=1024,  # total length: prompt tokens + generated tokens
            # temperature / top_p / top_k only take effect when sampling is on;
            # without do_sample=True they are ignored under greedy decoding.
            do_sample=True,
            top_p=1.0,
            temperature=0.5,
            top_k=50,
        )
        # generate() returns a batch; the single prompt is at index 0.
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        logging.exception(f"Error generating text: {e}")
        return "Error in text generation."

    logging.info("Text generated successfully.")
    return generated_text

# Gradio callback
def gradio_interface(instruction):
    """
    Handles one Gradio request: obtain the model, then generate a reply.

    Args:
        instruction (str): Text typed into the input box.
    Returns:
        str: The generated text, or a user-facing message when the
        model/tokenizer could not be obtained, the input was rejected with a
        ValueError, or any other exception occurred.
    """
    model, tokenizer = load_model()

    if model and tokenizer:
        try:
            return generate_text(model, tokenizer, instruction)
        except ValueError as bad_input:
            return f"Invalid input: {bad_input}"
        except Exception as err:
            logging.error(f"Error during text generation: {err}")
            return "An error occurred. Please try again."

    # load_model() handed back a falsy model or tokenizer.
    return "Failed to load model or tokenizer. Please check your configuration."

# Assemble the Gradio UI: a single text input wired to a single text output.
instruction_box = gr.Textbox(
    label="Enter an instruction for the model:",
    placeholder="Type your instruction here...",
)
response_box = gr.Textbox(label="Generated Text:")

iface = gr.Interface(
    title="Penetration Testing AI Assistant",
    description="This tool allows you to interact with a pre-trained AI model for penetration testing assistance. Enter an instruction to generate a response.",
    fn=gradio_interface,
    inputs=instruction_box,
    outputs=response_box,
)

# Start serving the interface.
# NOTE: this runs at module level, so it also runs when the file is imported.
iface.launch()