import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
import spacy
from typing import List, Dict, Tuple
import logging
import os
import gradio as gr
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import time
from datetime import datetime
import openpyxl
from openpyxl import Workbook
from openpyxl.utils import get_column_letter
from io import BytesIO
import base64
import hashlib
import hmac  # for constant-time hash comparison in is_admin_password
import requests
import tempfile
from pathlib import Path
import mimetypes

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MAX_LENGTH = 512
MODEL_NAME = "microsoft/deberta-v3-small"
WINDOW_SIZE = 6
WINDOW_OVERLAP = 2
CONFIDENCE_THRESHOLD = 0.65
BATCH_SIZE = 8  # Reduced batch size for CPU
MAX_WORKERS = 4  # Number of worker threads for processing

# IMPORTANT: Set PyTorch thread configuration at the module level
# before any parallel work starts
if not torch.cuda.is_available():
    # Set thread configuration only once at the beginning
    torch.set_num_threads(MAX_WORKERS)
    try:
        # Only set interop threads if it hasn't been set already
        torch.set_num_interop_threads(MAX_WORKERS)
    except RuntimeError as e:
        logger.warning(f"Could not set interop threads: {str(e)}")

# Get password hash from environment variable (more secure)
ADMIN_PASSWORD_HASH = os.environ.get('ADMIN_PASSWORD_HASH')
if not ADMIN_PASSWORD_HASH:
    ADMIN_PASSWORD_HASH = "5e22d1ed71b273b1b2b5331f2d3e0f6cf34595236f201c6924d6bc81de27cdcb"

# Excel file path for logs
EXCEL_LOG_PATH = "/tmp/prediction_logs.xlsx"

# OCR API settings
OCR_API_KEY = "9e11346f1288957"  # This is a partial key - replace with the full one
OCR_API_ENDPOINT = "https://api.ocr.space/parse/image"
OCR_MAX_PDF_PAGES = 3
OCR_MAX_FILE_SIZE_MB = 1

# Configure logging for OCR module
ocr_logger = logging.getLogger("ocr_module")
ocr_logger.setLevel(logging.INFO)

class OCRProcessor:
    """
    Handles OCR processing of image and document files using OCR.space API
    """
    def __init__(self, api_key: str = OCR_API_KEY):
        self.api_key = api_key
        self.endpoint = OCR_API_ENDPOINT

    def process_file(self, file_path: str) -> Dict:
        """
        Process a file using OCR.space API
        """
        start_time = time.time()
        ocr_logger.info(f"Starting OCR processing for file: {os.path.basename(file_path)}")

        # Validate file size
        file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
        if file_size_mb > OCR_MAX_FILE_SIZE_MB:
            ocr_logger.warning(f"File size ({file_size_mb:.2f} MB) exceeds limit of {OCR_MAX_FILE_SIZE_MB} MB")
            return {
                "success": False,
                "error": f"File size ({file_size_mb:.2f} MB) exceeds limit of {OCR_MAX_FILE_SIZE_MB} MB",
                "text": ""
            }

        # Determine file type and handle accordingly
        file_type = self._get_file_type(file_path)
        ocr_logger.info(f"Detected file type: {file_type}")

        # Prepare the API request
        with open(file_path, 'rb') as f:
            file_data = f.read()

        # Set up API parameters
        payload = {
            'isOverlayRequired': 'false',
            'language': 'eng',
            'OCREngine': '2',  # Use more accurate engine
            'scale': 'true',
            'detectOrientation': 'true',
        }

        # For PDF files, check page count limitations
        if file_type == 'application/pdf':
            ocr_logger.info("PDF document detected, enforcing page limit")
            payload['filetype'] = 'PDF'

        # Prepare file for OCR API
        files = {
            'file': (os.path.basename(file_path), file_data, file_type)
        }
        headers = {
            'apikey': self.api_key,
        }

        # Make the OCR API request
        try:
            ocr_logger.info("Sending request to OCR.space API")
            response = requests.post(
                self.endpoint,
                files=files,
                data=payload,
                headers=headers
            )
            response.raise_for_status()
            result = response.json()

            # Process the OCR results
            if result.get('OCRExitCode') in [1, 2]:  # Success or partial success
                extracted_text = self._extract_text_from_result(result)
                processing_time = time.time() - start_time
                ocr_logger.info(f"OCR processing completed in {processing_time:.2f} seconds")
                return {
                    "success": True,
                    "text": extracted_text,
                    "word_count": len(extracted_text.split()),
                    "processing_time_ms": int(processing_time * 1000)
                }
            else:
                ocr_logger.error(f"OCR API error: {result.get('ErrorMessage', 'Unknown error')}")
                return {
                    "success": False,
                    "error": result.get('ErrorMessage', 'OCR processing failed'),
                    "text": ""
                }
        except requests.exceptions.RequestException as e:
            ocr_logger.error(f"OCR API request failed: {str(e)}")
            return {
                "success": False,
                "error": f"OCR API request failed: {str(e)}",
                "text": ""
            }

    def _extract_text_from_result(self, result: Dict) -> str:
        """
        Extract all text from the OCR API result
        """
        extracted_text = ""
        if 'ParsedResults' in result and result['ParsedResults']:
            for parsed_result in result['ParsedResults']:
                if parsed_result.get('ParsedText'):
                    extracted_text += parsed_result['ParsedText']
        return extracted_text

    def _get_file_type(self, file_path: str) -> str:
        """
        Determine MIME type of a file
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            # Default to binary if MIME type can't be determined
            return 'application/octet-stream'
        return mime_type

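# Example usage (sketch): running OCRProcessor standalone. "sample.png" is a
# hypothetical local file, and this assumes OCR_API_KEY holds a valid key; the
# calls are shown as comments so nothing runs at import time.
#
#   processor = OCRProcessor()
#   result = processor.process_file("sample.png")
#   if result["success"]:
#       print(f"{result['word_count']} words in {result['processing_time_ms']} ms")
#   else:
#       print("OCR failed:", result["error"])
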
def is_admin_password(input_text: str) -> bool:
    """
    Check if the input text matches the admin password using secure hash comparison.
    """
    # Hash the input text
    input_hash = hashlib.sha256(input_text.strip().encode()).hexdigest()
    # Compare hashes in constant time to prevent timing attacks
    # (a plain == comparison can short-circuit and leak timing information)
    return hmac.compare_digest(input_hash, ADMIN_PASSWORD_HASH)

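# To generate a value for ADMIN_PASSWORD_HASH (sketch; 'your-password' is a
# placeholder):
#   python -c "import hashlib; print(hashlib.sha256('your-password'.encode()).hexdigest())"
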
class TextWindowProcessor:
    def __init__(self):
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            logger.info("Downloading spacy model...")
            spacy.cli.download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

        if 'sentencizer' not in self.nlp.pipe_names:
            self.nlp.add_pipe('sentencizer')

        disabled_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'sentencizer']
        self.nlp.disable_pipes(*disabled_pipes)

        # Initialize thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)

    def split_into_sentences(self, text: str) -> List[str]:
        doc = self.nlp(text)
        return [str(sent).strip() for sent in doc.sents]

    def create_windows(self, sentences: List[str], window_size: int, overlap: int) -> List[str]:
        if len(sentences) < window_size:
            return [" ".join(sentences)]

        windows = []
        stride = window_size - overlap
        for i in range(0, len(sentences) - window_size + 1, stride):
            window = sentences[i:i + window_size]
            windows.append(" ".join(window))
        return windows

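    # Worked example (sketch): with WINDOW_SIZE=6 and WINDOW_OVERLAP=2 the
    # stride is 4, so a 10-sentence text yields windows over sentence indices
    # [0:6] and [4:10]; sentences 4 and 5 are scored in both windows.
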
    def create_centered_windows(self, sentences: List[str], window_size: int) -> Tuple[List[str], List[List[int]]]:
        """Create one window per sentence, centered on that sentence, with boundary handling."""
        windows = []
        window_sentence_indices = []
        for i in range(len(sentences)):
            # Calculate window boundaries centered on current sentence
            half_window = window_size // 2
            start_idx = max(0, i - half_window)
            end_idx = min(len(sentences), i + half_window + 1)

            # Create the window
            window = sentences[start_idx:end_idx]
            windows.append(" ".join(window))
            window_sentence_indices.append(list(range(start_idx, end_idx)))
        return windows, window_sentence_indices

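    # Worked example (sketch): with window_size=6, half_window is 3, so the
    # first sentence (i=0) gets the window spanning indices [0:4], while a
    # mid-document sentence i spans [i-3 : i+4] (7 sentences including itself).
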
class TextClassifier:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = MODEL_NAME
        self.tokenizer = None
        self.model = None
        self.processor = TextWindowProcessor()
        self.initialize_model()

    def initialize_model(self):
        """Initialize the model and tokenizer."""
        logger.info("Initializing model and tokenizer...")
        from transformers import DebertaV2TokenizerFast

        self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
            self.model_name,
            model_max_length=MAX_LENGTH,
            use_fast=True
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=2
        ).to(self.device)

        model_path = "model_20250209_184929_acc1.0000.pt"
        if os.path.exists(model_path):
            logger.info(f"Loading custom model from {model_path}")
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning("Custom model file not found. Using base model.")

        self.model.eval()

    def quick_scan(self, text: str) -> Dict:
        """Perform a quick scan using simple window analysis."""
        if not text.strip():
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        sentences = self.processor.split_into_sentences(text)
        windows = self.processor.create_windows(sentences, WINDOW_SIZE, WINDOW_OVERLAP)

        predictions = []
        # Process windows in smaller batches for CPU efficiency
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            for idx, window in enumerate(batch_windows):
                prediction = {
                    'window': window,
                    'human_prob': probs[idx][1].item(),
                    'ai_prob': probs[idx][0].item(),
                    'prediction': 'human' if probs[idx][1] > probs[idx][0] else 'ai'
                }
                predictions.append(prediction)

            # Clean up GPU memory if available
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_windows': 0
            }

        avg_human_prob = sum(p['human_prob'] for p in predictions) / len(predictions)
        avg_ai_prob = sum(p['ai_prob'] for p in predictions) / len(predictions)

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_windows': len(predictions)
        }

    def detailed_scan(self, text: str) -> Dict:
        """Perform a detailed scan with improved sentence-level analysis."""
        # Clean up trailing whitespace
        text = text.rstrip()

        if not text.strip():
            return {
                'sentence_predictions': [],
                'highlighted_text': '',
                'full_text': '',
                'overall_prediction': {
                    'prediction': 'unknown',
                    'confidence': 0.0,
                    'num_sentences': 0
                }
            }

        sentences = self.processor.split_into_sentences(text)
        if not sentences:
            return {}

        # Create centered windows for each sentence
        windows, window_sentence_indices = self.processor.create_centered_windows(sentences, WINDOW_SIZE)

        # Track scores for each sentence
        sentence_appearances = {i: 0 for i in range(len(sentences))}
        sentence_scores = {i: {'human_prob': 0.0, 'ai_prob': 0.0} for i in range(len(sentences))}

        # Process windows in batches
        for i in range(0, len(windows), BATCH_SIZE):
            batch_windows = windows[i:i + BATCH_SIZE]
            batch_indices = window_sentence_indices[i:i + BATCH_SIZE]

            inputs = self.tokenizer(
                batch_windows,
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = F.softmax(outputs.logits, dim=-1)

            # Attribute predictions with weighted scoring. Windows are created
            # one per sentence, so the window at global index (i + window_idx)
            # is centered on that same sentence; look up its position within
            # the window rather than assuming it sits in the middle (near the
            # start or end of the document it does not).
            for window_idx, indices in enumerate(batch_indices):
                center_pos = indices.index(i + window_idx)
                center_weight = 0.7  # Higher weight for center sentence
                edge_weight = 0.3 / (len(indices) - 1) if len(indices) > 1 else 0  # Distribute remaining weight

                for pos, sent_idx in enumerate(indices):
                    # Apply higher weight to center sentence
                    weight = center_weight if pos == center_pos else edge_weight
                    sentence_appearances[sent_idx] += weight
                    sentence_scores[sent_idx]['human_prob'] += weight * probs[window_idx][1].item()
                    sentence_scores[sent_idx]['ai_prob'] += weight * probs[window_idx][0].item()

            # Clean up memory
            del inputs, outputs, probs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Calculate final predictions with boundary smoothing
        sentence_predictions = []
        for i in range(len(sentences)):
            if sentence_appearances[i] > 0:
                human_prob = sentence_scores[i]['human_prob'] / sentence_appearances[i]
                ai_prob = sentence_scores[i]['ai_prob'] / sentence_appearances[i]

                # Apply minimal smoothing at prediction boundaries
                if i > 0 and i < len(sentences) - 1:
                    prev_human = sentence_scores[i-1]['human_prob'] / max(sentence_appearances[i-1], 1e-10)
                    prev_ai = sentence_scores[i-1]['ai_prob'] / max(sentence_appearances[i-1], 1e-10)
                    next_human = sentence_scores[i+1]['human_prob'] / max(sentence_appearances[i+1], 1e-10)
                    next_ai = sentence_scores[i+1]['ai_prob'] / max(sentence_appearances[i+1], 1e-10)

                    # Check if we're at a prediction boundary
                    current_pred = 'human' if human_prob > ai_prob else 'ai'
                    prev_pred = 'human' if prev_human > prev_ai else 'ai'
                    next_pred = 'human' if next_human > next_ai else 'ai'

                    if current_pred != prev_pred or current_pred != next_pred:
                        # Small adjustment at boundaries
                        smooth_factor = 0.1
                        human_prob = (human_prob * (1 - smooth_factor) +
                                      (prev_human + next_human) * smooth_factor / 2)
                        ai_prob = (ai_prob * (1 - smooth_factor) +
                                   (prev_ai + next_ai) * smooth_factor / 2)

                sentence_predictions.append({
                    'sentence': sentences[i],
                    'human_prob': human_prob,
                    'ai_prob': ai_prob,
                    'prediction': 'human' if human_prob > ai_prob else 'ai',
                    'confidence': max(human_prob, ai_prob)
                })

        return {
            'sentence_predictions': sentence_predictions,
            'highlighted_text': self.format_predictions_html(sentence_predictions),
            'full_text': text,
            'overall_prediction': self.aggregate_predictions(sentence_predictions)
        }

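    # Smoothing example (sketch): at a boundary sentence with human_prob 0.60
    # whose neighbors average (prev + next) / 2 = 0.40, the smoothed value is
    # 0.60 * 0.9 + 0.40 * 0.1 = 0.58.
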
    def format_predictions_html(self, sentence_predictions: List[Dict]) -> str:
        """Format predictions as HTML with color-coding."""
        html_parts = []

        for pred in sentence_predictions:
            sentence = pred['sentence']
            confidence = pred['confidence']

            if confidence >= CONFIDENCE_THRESHOLD:
                if pred['prediction'] == 'human':
                    color = "#90EE90"  # Light green
                else:
                    color = "#FFB6C6"  # Light red
            else:
                if pred['prediction'] == 'human':
                    color = "#E8F5E9"  # Very light green
                else:
                    color = "#FFEBEE"  # Very light red

            html_parts.append(f'<span style="background-color: {color};">{sentence}</span>')

        return " ".join(html_parts)

    def aggregate_predictions(self, predictions: List[Dict]) -> Dict:
        """Aggregate predictions from multiple sentences into a single prediction."""
        if not predictions:
            return {
                'prediction': 'unknown',
                'confidence': 0.0,
                'num_sentences': 0
            }

        total_human_prob = sum(p['human_prob'] for p in predictions)
        total_ai_prob = sum(p['ai_prob'] for p in predictions)
        num_sentences = len(predictions)

        avg_human_prob = total_human_prob / num_sentences
        avg_ai_prob = total_ai_prob / num_sentences

        return {
            'prediction': 'human' if avg_human_prob > avg_ai_prob else 'ai',
            'confidence': max(avg_human_prob, avg_ai_prob),
            'num_sentences': num_sentences
        }

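    # Worked example (sketch): three sentences with human_prob 0.8, 0.6, 0.7
    # (ai_prob 0.2, 0.4, 0.3) average to human 0.70 vs ai 0.30, so the
    # aggregate is prediction 'human' with confidence 0.70.
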
# Function to handle file upload, OCR processing, and text analysis
def handle_file_upload_and_analyze(file_obj, mode: str, classifier) -> tuple:
    """
    Handle file upload, OCR processing, and text analysis
    """
    if file_obj is None:
        return (
            "No file uploaded",
            "Please upload a file to analyze",
            "No file uploaded for analysis"
        )

    # Create a temporary file with an appropriate extension based on content
    content_start = file_obj[:20]  # Look at the first few bytes

    # Default to .bin extension
    file_ext = ".bin"

    # Try to detect PDF files
    if content_start.startswith(b'%PDF'):
        file_ext = ".pdf"
    # For images, detect by common magic numbers
    elif content_start.startswith(b'\xff\xd8'):  # JPEG
        file_ext = ".jpg"
    elif content_start.startswith(b'\x89PNG'):  # PNG
        file_ext = ".png"
    elif content_start.startswith(b'GIF'):  # GIF
        file_ext = ".gif"

    # Create a temporary file with the detected extension
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
        temp_file_path = temp_file.name
        # Write uploaded file data to the temporary file
        temp_file.write(file_obj)

    try:
        # Process the file with OCR
        ocr_processor = OCRProcessor()
        ocr_result = ocr_processor.process_file(temp_file_path)

        if not ocr_result["success"]:
            return (
                "OCR Processing Error",
                ocr_result["error"],
                "Failed to extract text from the uploaded file"
            )

        # Get the extracted text
        extracted_text = ocr_result["text"]

        # If no text was extracted
        if not extracted_text.strip():
            return (
                "No text extracted",
                "The OCR process did not extract any text from the uploaded file.",
                "No text was found in the uploaded file"
            )

        # Call the original text analysis function with the extracted text
        return analyze_text(extracted_text, mode, classifier)
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

def initialize_excel_log():
    """Initialize the Excel log file if it doesn't exist."""
    if not os.path.exists(EXCEL_LOG_PATH):
        wb = Workbook()
        ws = wb.active
        ws.title = "Prediction Logs"

        # Set column headers
        headers = ["timestamp", "word_count", "prediction", "confidence",
                   "execution_time_ms", "analysis_mode", "full_text"]
        for col_num, header in enumerate(headers, 1):
            ws.cell(row=1, column=col_num, value=header)

        # Adjust column widths for better readability
        ws.column_dimensions[get_column_letter(1)].width = 20   # timestamp
        ws.column_dimensions[get_column_letter(2)].width = 10   # word_count
        ws.column_dimensions[get_column_letter(3)].width = 10   # prediction
        ws.column_dimensions[get_column_letter(4)].width = 10   # confidence
        ws.column_dimensions[get_column_letter(5)].width = 15   # execution_time_ms
        ws.column_dimensions[get_column_letter(6)].width = 15   # analysis_mode
        ws.column_dimensions[get_column_letter(7)].width = 100  # full_text

        # Save the workbook
        wb.save(EXCEL_LOG_PATH)
        logger.info(f"Initialized Excel log file at {EXCEL_LOG_PATH}")

def log_prediction_data(input_text, word_count, prediction, confidence, execution_time, mode):
    """Log prediction data to an Excel file in the /tmp directory."""
    # Initialize the Excel file if it doesn't exist
    if not os.path.exists(EXCEL_LOG_PATH):
        initialize_excel_log()

    try:
        # Load the existing workbook
        wb = openpyxl.load_workbook(EXCEL_LOG_PATH)
        ws = wb.active

        # Get the next row number
        next_row = ws.max_row + 1

        # Clean up the input text for Excel storage (replace problematic characters)
        cleaned_text = input_text.replace("\n", " ")

        # Prepare row data
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        row_data = [
            timestamp,
            word_count,
            prediction,
            f"{confidence:.2f}",
            f"{execution_time:.2f}",
            mode,
            cleaned_text
        ]

        # Add the data to the worksheet
        for col_num, value in enumerate(row_data, 1):
            ws.cell(row=next_row, column=col_num, value=value)

        # Save the workbook
        wb.save(EXCEL_LOG_PATH)
        logger.info(f"Successfully logged prediction data to {EXCEL_LOG_PATH}")
        return True
    except Exception as e:
        logger.error(f"Error logging prediction data to Excel: {str(e)}")
        return False

def get_logs_as_base64():
    """Read the Excel logs file and return as base64 for downloading."""
    if not os.path.exists(EXCEL_LOG_PATH):
        return None

    try:
        # Read the Excel file into memory
        with open(EXCEL_LOG_PATH, "rb") as f:
            file_data = f.read()

        # Encode the file as base64
        base64_data = base64.b64encode(file_data).decode('utf-8')
        return base64_data
    except Exception as e:
        logger.error(f"Error reading Excel logs: {str(e)}")
        return None

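# Example (sketch): inspecting a downloaded copy of the log file with openpyxl.
# "prediction_logs.xlsx" here is a local copy, not the live /tmp file.
#
#   wb = openpyxl.load_workbook("prediction_logs.xlsx")
#   for row in wb.active.iter_rows(min_row=2, values_only=True):
#       timestamp, word_count, prediction, confidence, *_ = row
#       print(timestamp, prediction, confidence)
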
def analyze_text(text: str, mode: str, classifier: TextClassifier) -> tuple:
    """Analyze text using specified mode and return formatted results."""
    # Check if the input text matches the admin password using secure comparison
    if is_admin_password(text.strip()):
        # Return logs instead of analysis
        base64_data = get_logs_as_base64()
        logs_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        logs_filename = f"prediction_logs_{logs_timestamp}.xlsx"

        if base64_data:
            # Create downloadable HTML with the logs
            html_content = f"""
            <div style="background-color: #e6f7ff; padding: 15px; border-radius: 5px;">
                <h3>Admin Access Granted - Prediction Logs</h3>
                <p>Logs retrieved at: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
                <p>Excel file contains all prediction data with full text of all submissions.</p>
                <a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{base64_data}"
                   download="{logs_filename}"
                   style="display: inline-block; margin-top: 10px; padding: 10px 15px;
                          background-color: #4CAF50; color: white; text-decoration: none;
                          border-radius: 4px;">
                    Download Excel Logs
                </a>
            </div>
            """
        else:
            html_content = """
            <div style="background-color: #ffe6e6; padding: 15px; border-radius: 5px;">
                <h3>Admin Access Granted - No Logs Found</h3>
                <p>No prediction logs were found or there was an error reading the logs file.</p>
            </div>
            """

        # Return special admin output instead of normal analysis
        return (
            html_content,
            f"Admin access granted. Logs retrieved at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            f"ADMIN MODE\nLogs available for download\nFile: {EXCEL_LOG_PATH}"
        )

    # Start timing for normal analysis
    start_time = time.time()

    # Count words in the text
    word_count = len(text.split())

    # If text is less than 200 words and detailed mode is selected, switch to quick mode
    original_mode = mode
    if word_count < 200 and mode == "detailed":
        mode = "quick"

    if mode == "quick":
        result = classifier.quick_scan(text)

        quick_analysis = f"""
PREDICTION: {result['prediction'].upper()}
Confidence: {result['confidence']*100:.1f}%
Windows analyzed: {result['num_windows']}
"""

        # Add note if mode was switched
        if original_mode == "detailed":
            quick_analysis += f"\n\nNote: Switched to quick mode because text contains only {word_count} words. Minimum 200 words required for detailed analysis."

        # Calculate execution time in milliseconds
        execution_time = (time.time() - start_time) * 1000

        # Log the prediction data
        log_prediction_data(
            input_text=text,
            word_count=word_count,
            prediction=result['prediction'],
            confidence=result['confidence'],
            execution_time=execution_time,
            mode=original_mode
        )

        return (
            text,  # No highlighting in quick mode
            "Quick scan mode - no sentence-level analysis available",
            quick_analysis
        )
    else:
        analysis = classifier.detailed_scan(text)

        detailed_analysis = []
        for pred in analysis['sentence_predictions']:
            confidence = pred['confidence'] * 100
            detailed_analysis.append(f"Sentence: {pred['sentence']}")
            detailed_analysis.append(f"Prediction: {pred['prediction'].upper()}")
            detailed_analysis.append(f"Confidence: {confidence:.1f}%")
            detailed_analysis.append("-" * 50)

        final_pred = analysis['overall_prediction']
        overall_result = f"""
FINAL PREDICTION: {final_pred['prediction'].upper()}
Overall confidence: {final_pred['confidence']*100:.1f}%
Number of sentences analyzed: {final_pred['num_sentences']}
"""

        # Calculate execution time in milliseconds
        execution_time = (time.time() - start_time) * 1000

        # Log the prediction data
        log_prediction_data(
            input_text=text,
            word_count=word_count,
            prediction=final_pred['prediction'],
            confidence=final_pred['confidence'],
            execution_time=execution_time,
            mode=original_mode
        )

        return (
            analysis['highlighted_text'],
            "\n".join(detailed_analysis),
            overall_result
        )

# Initialize the classifier globally
classifier = TextClassifier()

# Create Gradio interface with a small file upload button next to the radio buttons
def setup_interface():
    # Create analyzer functions that capture the classifier
    def analyze_text_wrapper(text, mode):
        return analyze_text(text, mode, classifier)

    def handle_file_upload_wrapper(file_obj, mode):
        if file_obj is None:
            return analyze_text_wrapper("", mode)
        return handle_file_upload_and_analyze(file_obj, mode, classifier)

    with gr.Blocks(title="AI Text Detector") as demo:
        gr.Markdown("# AI Text Detector")

        with gr.Row():
            # Left column - Input
            with gr.Column():
                text_input = gr.Textbox(
                    lines=8,
                    placeholder="Enter text to analyze...",
                    label="Input Text"
                )

                # Custom container for radio buttons and small file upload
                with gr.Column():
                    gr.Markdown("Analysis Mode", elem_id="analysis-mode-label")
                    gr.Markdown("Quick mode for faster analysis. Detailed mode for sentence-level analysis.",
                                elem_classes=["description-text"])
                    with gr.Row():
                        mode_selection = gr.Radio(
                            choices=["quick", "detailed"],
                            value="quick",
                            label="",
                            elem_classes=["mode-radio"]
                        )
                        # Tiny file button (hidden until CSS applies)
                        file_upload = gr.File(
                            file_types=["image", "pdf", "doc", "docx"],
                            type="binary",
                            label="",
                            elem_id="tiny-file-upload"
                        )

                analyze_button = gr.Button("Analyze Text")

            # Right column - Results
            with gr.Column():
                output_html = gr.HTML(label="Highlighted Analysis")
                output_sentences = gr.Textbox(label="Sentence-by-Sentence Analysis", lines=10)
                output_result = gr.Textbox(label="Overall Result", lines=4)

        # Connect the components
        analyze_button.click(
            analyze_text_wrapper,
            inputs=[text_input, mode_selection],
            outputs=[output_html, output_sentences, output_result]
        )

        file_upload.change(
            handle_file_upload_wrapper,
            inputs=[file_upload, mode_selection],
            outputs=[output_html, output_sentences, output_result]
        )

        # Custom CSS that transforms the file upload into a small icon
        gr.HTML("""
        <style>
        /* Shrink the default file upload container */
        #tiny-file-upload {
            position: relative;
            width: 40px !important;
            vertical-align: middle !important;
            margin-left: 10px !important;
            margin-top: 5px !important;
        }
        #tiny-file-upload > .wrap {
            position: relative;
            display: inline-block;
            width: 32px !important;
            height: 32px !important;
            overflow: hidden;
        }
        /* Hide all the default upload UI elements */
        #tiny-file-upload .file-preview,
        #tiny-file-upload .file-preview p,
        #tiny-file-upload .hidden-upload,
        #tiny-file-upload .upload-prompt,
        #tiny-file-upload .remove-file,
        #tiny-file-upload > label {
            display: none !important;
        }
        /* Create a paperclip-like button */
        #tiny-file-upload .wrap:before {
            content: "📎";
            font-size: 20px;
            position: absolute;
            top: 2px;
            left: 5px;
            opacity: 0.7;
            cursor: pointer;
            z-index: 10;
        }
        #tiny-file-upload .wrap:hover:before {
            opacity: 1;
        }
        /* Adjust upload zone to be clickable but mostly invisible */
        #tiny-file-upload .upload-button {
            position: absolute;
            width: 32px !important;
            height: 32px !important;
            left: 0 !important;
            top: 0 !important;
            opacity: 0.01 !important;
            margin: 0 !important;
            padding: 0 !important;
            cursor: pointer !important;
            z-index: 5 !important;
        }
        /* Fix the radio button alignment */
        .mode-radio {
            margin-top: 0 !important;
            display: inline-block !important;
        }
        /* Make description text smaller */
        .description-text {
            font-size: 0.85em;
            color: #666;
            margin-top: -5px;
            margin-bottom: 5px;
        }
        /* More compact layout */
        #analysis-mode-label {
            margin-bottom: 0 !important;
        }
        </style>
        """)

    return demo

# Setup the app with CORS middleware
def setup_app():
    demo = setup_interface()

    # Get the FastAPI app from Gradio
    app = demo.app

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # For development
        allow_credentials=True,
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["*"],
    )

    return demo

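# Note (assumption): depending on the Gradio version, demo.app may only be
# created once launch() is called. If it is unavailable here, one alternative
# (sketch) is to mount the Blocks onto your own FastAPI instance:
#   from fastapi import FastAPI
#   app = FastAPI()
#   gr.mount_gradio_app(app, demo, path="/")
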
# Initialize the application
if __name__ == "__main__":
    demo = setup_app()

    # Start the server
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )