import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from concurrent.futures import ThreadPoolExecutor
import tldextract
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import time
import re
from urllib.parse import urlparse
import string
from collections import Counter
class ONNXPhishingDetector:
    def __init__(self, model_path="phishing_detector.onnx"):
        # Initialize with optimized settings and cache
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", local_files_only=False)
        self.session = ort.InferenceSession(
            model_path,
            providers=['CPUExecutionProvider'],  # Removed CoreMLExecutionProvider due to errors
            sess_options=self._get_optimized_options()
        )
        self.model_expected_length = 128
        self.extract = tldextract.TLDExtract(include_psl_private_domains=True)
        self.url_cache = {}  # Cache for URL analysis results
        self.thread_pool = ThreadPoolExecutor(max_workers=32)  # Increased thread pool size for better parallelism

        # Comprehensive lists for pattern matching
        self.suspicious_keywords = {
            'login': 'credential-related',
            'signin': 'credential-related',
            'account': 'credential-related',
            'password': 'credential-related',
            'verify': 'verification-related',
            'secure': 'security-related',
            'banking': 'financial-related',
            'paypal': 'financial-related',
            'wallet': 'financial-related',
            'bitcoin': 'cryptocurrency-related',
            'crypto': 'cryptocurrency-related',
            'authenticate': 'authentication-related',
            'authorize': 'authentication-related',
            'validation': 'verification-related',
            'confirm': 'verification-related'
        }
        self.legitimate_tlds = {'.com', '.org', '.net', '.edu', '.gov', '.mil', '.int'}
        self.suspicious_tlds = {'.xyz', '.top', '.buzz', '.country', '.stream', '.gq', '.tk', '.ml'}

        # Brand protection patterns
        self.common_brands = {
            'google', 'facebook', 'apple', 'microsoft', 'amazon', 'paypal',
            'netflix', 'linkedin', 'twitter', 'instagram'
        }

    def _get_optimized_options(self):
        # Optimize ONNX session options
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Note: Changing these thread values doesn't significantly impact performance
        options.intra_op_num_threads = 4
        options.inter_op_num_threads = 4
        options.enable_mem_pattern = True
        options.enable_cpu_mem_arena = True
        return options
    def _calculate_entropy(self, text):
        """Calculate Shannon entropy of domain to detect random-looking strings"""
        if not text:  # Guard against empty strings to avoid division by zero
            return 0.0
        prob = [float(text.count(c)) / len(text) for c in set(text)]
        entropy = -sum(p * np.log2(p) for p in prob)
        return entropy
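    # Illustrative note (not from the original source): "google" has an entropy of roughly
    # 1.92 bits, and a label needs at least 23 distinct characters before its entropy can
    # exceed the 4.5 threshold used in _analyze_url_structure (log2(23) ~= 4.52), so only
    # long, high-variety domain names are flagged as random-looking.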
    def _check_character_distribution(self, domain):
        """Analyze character distribution patterns"""
        if not domain:  # Handle empty domain case
            return 0.0, 0.0
        char_counts = Counter(domain)
        total_chars = len(domain)
        # Check for unusual character distributions
        digit_ratio = sum(c.isdigit() for c in domain) / total_chars
        consonant_ratio = sum(c in 'bcdfghjklmnpqrstvwxyz' for c in domain.lower()) / total_chars
        return digit_ratio, consonant_ratio

    def _analyze_url_structure(self, url, ext):
        reasons = []
        parsed = urlparse(url)
        domain = ext.domain

        # 1. Domain Analysis
        domain_length = len(domain)
        entropy = self._calculate_entropy(domain)
        digit_ratio, consonant_ratio = self._check_character_distribution(domain)

        # Check domain composition
        if domain_length > 20:
            reasons.append(f"Suspicious: Domain length ({domain_length} chars) exceeds normal range")
        if entropy > 4.5:
            reasons.append(f"Suspicious: High domain entropy ({entropy:.2f}) suggests randomly generated name")
        if digit_ratio > 0.4:
            reasons.append(f"Suspicious: Unusual number of digits ({digit_ratio:.1%} of domain)")
        if consonant_ratio > 0.7:
            reasons.append(f"Suspicious: Unusual consonant pattern ({consonant_ratio:.1%} of domain)")

        # 2. Brand Impersonation Detection
        for brand in self.common_brands:
            if brand in domain and brand != domain:
                if re.search(f"{brand}[^a-zA-Z]", domain) or re.search(f"[^a-zA-Z]{brand}", domain):
                    reasons.append(f"High Risk: Potential brand impersonation of {brand}")

        # 3. URL Component Analysis
        if parsed.username or parsed.password:
            reasons.append("High Risk: URL contains embedded credentials")
        if parsed.port and parsed.port not in (80, 443):
            reasons.append(f"Suspicious: Non-standard port number ({parsed.port})")

        # 4. Path Analysis
        if parsed.path:
            path_segments = parsed.path.split('/')
            if len(path_segments) > 4:
                reasons.append(f"Suspicious: Deep URL structure ({len(path_segments)} levels)")
            # Check for suspicious file extensions
            if any(segment.endswith(('.exe', '.dll', '.bat', '.sh')) for segment in path_segments):
                reasons.append("High Risk: Contains executable file extension")

        # 5. Query Parameter Analysis
        if parsed.query:
            query_params = parsed.query.split('&')
            suspicious_params = [p for p in query_params if any(k in p.lower() for k in ['pass', 'pwd', 'token', 'key'])]
            if suspicious_params:
                reasons.append("Suspicious: Query contains sensitive parameter names")

        # 6. Special Pattern Detection
        if len(re.findall(r'[.-]', domain)) > 4:
            reasons.append("Suspicious: Excessive use of dots/hyphens in domain")
        if re.search(r'([a-zA-Z0-9])\1{3,}', domain):
            reasons.append("Suspicious: Repeated character pattern detected")
        # 7. TLD Analysis (ext.suffix has no leading dot, so add one before comparing)
        suffix = f".{ext.suffix}"
        if suffix in self.suspicious_tlds:
            reasons.append(f"Suspicious: Known high-risk TLD ({suffix})")
        elif suffix not in self.legitimate_tlds:
            reasons.append(f"Suspicious: Uncommon TLD ({suffix})")
        # 8. Keyword Analysis
        found_keywords = []
        for keyword, category in self.suspicious_keywords.items():
            if keyword in f"{domain}{parsed.path}".lower():
                found_keywords.append(f"{keyword} ({category})")
        if found_keywords:
            reasons.append(f"Suspicious: Contains sensitive keywords: {', '.join(found_keywords)}")

        return reasons

    def _batch_preprocess(self, urls):
        processed = []
        for url in urls:
            url = url.strip().lower()
            if not url.startswith(('http://', 'https://')):
                url = f'http://{url}'
            processed.append(url)
        return processed

    def _batch_tokenize(self, urls):
        return self.tokenizer(
            urls,
            max_length=self.model_expected_length,
            truncation=True,
            padding="max_length",
            return_tensors="np"
        )
    def _predict_thread(self, urls):
        """Process a batch of URLs in a separate thread"""
        processed_urls = self._batch_preprocess(urls)
        inputs = self._batch_tokenize(processed_urls)
        ort_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        try:
            logits = self.session.run(None, ort_inputs)[0]
            probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
            results = []
            # Analyze the normalized URLs (so urlparse sees a scheme), but report/cache the originals
            for url, processed, prob in zip(urls, processed_urls, probabilities[:, 1]):
                ext = self.extract(processed)
                reasons = self._analyze_url_structure(processed, ext)
                if prob > 0.99:
                    reasons.append(f"Critical: ML model detected strong phishing patterns (confidence: {prob:.2%})")
                    verdict = "phishing"
                else:
                    if not reasons:
                        reasons = ["No suspicious patterns detected"]
                    verdict = "legitimate"
                result = {
                    "url": url,
                    "verdict": verdict,
                    "confidence": float(prob),
                    "reasons": set(reasons)
                }
                self.url_cache[url] = result
                results.append(result)
            return results
        except Exception:
            # Fallback to rule-based analysis if model inference fails
            results = []
            for url, processed in zip(urls, processed_urls):
                ext = self.extract(processed)
                reasons = self._analyze_url_structure(processed, ext)
                # Determine verdict based on rule analysis only
                if any("High Risk" in reason for reason in reasons):
                    verdict = "phishing"
                    confidence = 0.95
                elif len(reasons) > 2:
                    verdict = "phishing"
                    confidence = 0.85
                else:
                    verdict = "legitimate"
                    confidence = 0.70
                if not reasons:
                    reasons = ["No suspicious patterns detected"]
                result = {
                    "url": url,
                    "verdict": verdict,
                    "confidence": float(confidence),
                    "reasons": set(reasons + ["Note: Using rule-based analysis due to model inference error"])
                }
                self.url_cache[url] = result
                results.append(result)
            return results
    async def _batch_predict(self, inputs):
        # Direct single-batch inference; kept as an alternative to the threaded path used by analyze_batch
        ort_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64)
        }
        return self.session.run(None, ort_inputs)[0]

    async def _batch_analyze(self, urls):
        # Returns only the phishing probabilities; analyze_batch below goes through _predict_thread instead
        processed_urls = self._batch_preprocess(urls)
        inputs = self._batch_tokenize(processed_urls)
        logits = await self._batch_predict(inputs)
        probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
        return probabilities[:, 1]
    async def analyze_batch(self, urls):
        results = []
        uncached_urls = []
        # Check cache first
        for url in urls:
            if url in self.url_cache:
                results.append(self.url_cache[url])
            else:
                uncached_urls.append(url)
        if uncached_urls:
            # Split URLs into smaller batches for multithreaded processing
            batch_size = 10  # Process 10 URLs per thread
            url_batches = [uncached_urls[i:i + batch_size] for i in range(0, len(uncached_urls), batch_size)]
            # Submit each batch to the thread pool; wrap the futures so they can be awaited
            # without blocking the event loop, and keep each batch paired with its own future
            futures = [(batch, asyncio.wrap_future(self.thread_pool.submit(self._predict_thread, batch)))
                       for batch in url_batches]
            # Collect results from all threads
            for batch, future in futures:
                try:
                    batch_results = await future
                    results.extend(batch_results)
                except Exception as e:
                    # Handle any unexpected errors in thread execution
                    print(f"Error processing batch: {str(e)}")
                    # Create fallback results for this batch
                    for url in batch:
                        results.append({
                            "url": url,
                            "verdict": "error",
                            "confidence": 0.0,
                            "reasons": {f"Error analyzing URL: {str(e)}"}
                        })
        return results
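
# A minimal sketch of using the detector directly, outside the HTTP API (illustrative only;
# it assumes phishing_detector.onnx is available in the working directory):
#
#     detector = ONNXPhishingDetector()
#     results = asyncio.run(detector.analyze_batch(["google.com", "paypal-login-verify.xyz"]))
#     for r in results:
#         print(r["url"], r["verdict"], f"{r['confidence']:.2%}")
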
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

detector = ONNXPhishingDetector()

class UrlList(BaseModel):
    urls: list[str]
@app.post("/scan")  # Register the endpoint referenced in the usage notes below
async def scan_urls(url_list: UrlList):
    start_time = time.time()
    # Process all URLs in a single batch with internal multithreading
    results = await detector.analyze_batch(url_list.urls)
    phishing_count = sum(1 for r in results if r["verdict"] == "phishing")
    avg_confidence = sum(r["confidence"] for r in results) / len(results) if results else 0
    if avg_confidence >= 0.99:
        overall_verdict = "malicious"
    else:
        overall_verdict = "safe"
    return {
        "time_taken": f"{time.time() - start_time:.2f}s",
        "total_urls": len(url_list.urls),
        "legitimate": len(url_list.urls) - phishing_count,
        "phishing": phishing_count,
        "overall_verdict": overall_verdict,
        "average_confidence": avg_confidence,
        "results": results
    }
# To run this file in the terminal:
# 1. Make sure you have all dependencies installed:
#    pip install fastapi uvicorn onnxruntime numpy transformers tldextract
#    (the DeBERTa tokenizer may additionally require sentencepiece)
# 2. Navigate to the directory containing this file
# 3. Run the command:
#    python -m backend.main
#    or, if you're already in the backend directory:
#    python main.py
# 4. The API will be available at http://localhost:8000
# 5. You can test it with curl:
#    curl -X POST "http://localhost:8000/scan" -H "Content-Type: application/json" -d '{"urls":["google.com", "suspicious-phishing-site.xyz"]}'
# 6. Or use tools like Postman to send POST requests to the /scan endpoint
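# 7. Or use a minimal Python client (illustrative sketch; assumes the requests package is installed):
#    import requests
#    resp = requests.post(
#        "http://localhost:8000/scan",
#        json={"urls": ["google.com", "suspicious-phishing-site.xyz"]},
#    )
#    print(resp.json()["overall_verdict"])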
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)