KavishNayeem committed
Commit 89610f9 · 1 Parent(s): 451012d

Add FastAPI ONNX backend

Files changed (3)
  1. Dockerfile +14 -0
  2. app.py +364 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.9
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,364 @@
+ import onnxruntime as ort
+ import numpy as np
+ from transformers import AutoTokenizer
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ import uvicorn
+ from concurrent.futures import ThreadPoolExecutor
+ import tldextract
+ from fastapi.middleware.cors import CORSMiddleware
+ import asyncio
+ import time
+ import re
+ from urllib.parse import urlparse
+ import string
+ from collections import Counter
+
+ class ONNXPhishingDetector:
+     def __init__(self, model_path="phishing_detector.onnx"):
+         # Initialize with optimized settings and cache
+         self.tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", local_files_only=False)
+         self.session = ort.InferenceSession(
+             model_path,
+             providers=['CPUExecutionProvider'],  # Removed CoreMLExecutionProvider due to errors
+             sess_options=self._get_optimized_options()
+         )
+         self.model_expected_length = 128
+         self.extract = tldextract.TLDExtract(include_psl_private_domains=True)
+         self.url_cache = {}  # Cache for URL analysis results
+         self.thread_pool = ThreadPoolExecutor(max_workers=32)  # Increased thread pool size for better parallelism
+
+         # Comprehensive lists for pattern matching
+         self.suspicious_keywords = {
+             'login': 'credential-related',
+             'signin': 'credential-related',
+             'account': 'credential-related',
+             'password': 'credential-related',
+             'verify': 'verification-related',
+             'secure': 'security-related',
+             'banking': 'financial-related',
+             'paypal': 'financial-related',
+             'wallet': 'financial-related',
+             'bitcoin': 'cryptocurrency-related',
+             'crypto': 'cryptocurrency-related',
+             'authenticate': 'authentication-related',
+             'authorize': 'authentication-related',
+             'validation': 'verification-related',
+             'confirm': 'verification-related'
+         }
+
+         self.legitimate_tlds = {'.com', '.org', '.net', '.edu', '.gov', '.mil', '.int'}
+         self.suspicious_tlds = {'.xyz', '.top', '.buzz', '.country', '.stream', '.gq', '.tk', '.ml'}
+
+         # Brand protection patterns
+         self.common_brands = {
+             'google', 'facebook', 'apple', 'microsoft', 'amazon', 'paypal',
+             'netflix', 'linkedin', 'twitter', 'instagram'
+         }
+
+     def _get_optimized_options(self):
+         # Optimize ONNX session options
+         options = ort.SessionOptions()
+         options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+         # Note: Changing these thread values doesn't significantly impact performance
+         options.intra_op_num_threads = 4
+         options.inter_op_num_threads = 4
+         options.enable_mem_pattern = True
+         options.enable_cpu_mem_arena = True
+         return options
+
+     def _calculate_entropy(self, text):
+         """Calculate Shannon entropy of domain to detect random-looking strings"""
+         prob = [float(text.count(c)) / len(text) for c in set(text)]
+         entropy = -sum(p * np.log2(p) for p in prob)
+         return entropy
+
+     def _check_character_distribution(self, domain):
+         """Analyze character distribution patterns"""
+         if not domain:  # Handle empty domain case
+             return 0.0, 0.0
+
+         char_counts = Counter(domain)
+         total_chars = len(domain)
+
+         # Check for unusual character distributions
+         digit_ratio = sum(c.isdigit() for c in domain) / total_chars
+         consonant_ratio = sum(c in 'bcdfghjklmnpqrstvwxyz' for c in domain.lower()) / total_chars
+
+         return digit_ratio, consonant_ratio
+
+     def _analyze_url_structure(self, url, ext):
+         reasons = []
+         parsed = urlparse(url)
+         domain = ext.domain
+
+         # 1. Domain Analysis
+         domain_length = len(domain)
+         entropy = self._calculate_entropy(domain)
+         digit_ratio, consonant_ratio = self._check_character_distribution(domain)
+
+         # Check domain composition
+         if domain_length > 20:
+             reasons.append(f"Suspicious: Domain length ({domain_length} chars) exceeds normal range")
+
+         if entropy > 4.5:
+             reasons.append(f"Suspicious: High domain entropy ({entropy:.2f}) suggests randomly generated name")
+
+         if digit_ratio > 0.4:
+             reasons.append(f"Suspicious: Unusual number of digits ({digit_ratio:.1%} of domain)")
+
+         if consonant_ratio > 0.7:
+             reasons.append(f"Suspicious: Unusual consonant pattern ({consonant_ratio:.1%} of domain)")
+
+         # 2. Brand Impersonation Detection
+         for brand in self.common_brands:
+             if brand in domain and brand != domain:
+                 if re.search(f"{brand}[^a-zA-Z]", domain) or re.search(f"[^a-zA-Z]{brand}", domain):
+                     reasons.append(f"High Risk: Potential brand impersonation of {brand}")
+
+         # 3. URL Component Analysis
+         if parsed.username or parsed.password:
+             reasons.append("High Risk: URL contains embedded credentials")
+
+         if parsed.port and parsed.port not in (80, 443):
+             reasons.append(f"Suspicious: Non-standard port number ({parsed.port})")
+
+         # 4. Path Analysis
+         if parsed.path:
+             path_segments = parsed.path.split('/')
+             if len(path_segments) > 4:
+                 reasons.append(f"Suspicious: Deep URL structure ({len(path_segments)} levels)")
+
+             # Check for suspicious file extensions
+             if any(segment.endswith(('.exe', '.dll', '.bat', '.sh')) for segment in path_segments):
+                 reasons.append("High Risk: Contains executable file extension")
+
+         # 5. Query Parameter Analysis
+         if parsed.query:
+             query_params = parsed.query.split('&')
+             suspicious_params = [p for p in query_params if any(k in p.lower() for k in ['pass', 'pwd', 'token', 'key'])]
+             if suspicious_params:
+                 reasons.append("Suspicious: Query contains sensitive parameter names")
+
+         # 6. Special Pattern Detection
+         if len(re.findall(r'[.-]', domain)) > 4:
+             reasons.append("Suspicious: Excessive use of dots/hyphens in domain")
+
+         if re.search(r'([a-zA-Z0-9])\1{3,}', domain):
+             reasons.append("Suspicious: Repeated character pattern detected")
+
+         # 7. TLD Analysis
+         if f".{ext.suffix}" in self.suspicious_tlds:  # tldextract suffixes have no leading dot, so add one before matching
+             reasons.append(f"Suspicious: Known high-risk TLD (.{ext.suffix})")
+         elif ext.suffix not in [tld.strip('.') for tld in self.legitimate_tlds]:
+             reasons.append(f"Suspicious: Uncommon TLD (.{ext.suffix})")
+
+         # 8. Keyword Analysis
+         found_keywords = []
+         for keyword, category in self.suspicious_keywords.items():
+             if keyword in f"{domain}{parsed.path}".lower():
+                 found_keywords.append(f"{keyword} ({category})")
+
+         if found_keywords:
+             reasons.append(f"Suspicious: Contains sensitive keywords: {', '.join(found_keywords)}")
+
+         return reasons
+
+     def _batch_preprocess(self, urls):
+         processed = []
+         for url in urls:
+             url = url.strip().lower()
+             if not url.startswith(('http://', 'https://')):
+                 url = f'http://{url}'
+             processed.append(url)
+         return processed
+
+     def _batch_tokenize(self, urls):
+         return self.tokenizer(
+             urls,
+             max_length=self.model_expected_length,
+             truncation=True,
+             padding="max_length",
+             return_tensors="np"
+         )
+
+     def _predict_thread(self, urls):
+         """Process a batch of URLs in a separate thread"""
+         processed_urls = self._batch_preprocess(urls)
+         inputs = self._batch_tokenize(processed_urls)
+
+         ort_inputs = {
+             "input_ids": inputs["input_ids"].astype(np.int64),
+             "attention_mask": inputs["attention_mask"].astype(np.int64)
+         }
+
+         try:
+             logits = self.session.run(None, ort_inputs)[0]
+             probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
+
+             results = []
+             for url, prob in zip(urls, probabilities[:, 1]):
+                 ext = self.extract(url)
+                 reasons = self._analyze_url_structure(url, ext)
+
+                 if prob > 0.99:
+                     reasons.append(f"Critical: ML model detected strong phishing patterns (confidence: {prob:.2%})")
+                     verdict = "phishing"
+                 else:
+                     if not reasons:
+                         reasons = ["No suspicious patterns detected"]
+                     verdict = "legitimate"
+
+                 result = {
+                     "url": url,
+                     "verdict": verdict,
+                     "confidence": float(prob),
+                     "reasons": set(reasons)
+                 }
+
+                 self.url_cache[url] = result
+                 results.append(result)
+
+             return results
+         except Exception as e:
+             # Fallback to rule-based analysis if model inference fails
+             results = []
+             for url in urls:
+                 ext = self.extract(url)
+                 reasons = self._analyze_url_structure(url, ext)
+
+                 # Determine verdict based on rule analysis only
+                 if any("High Risk" in reason for reason in reasons):
+                     verdict = "phishing"
+                     confidence = 0.95
+                 elif len(reasons) > 2:
+                     verdict = "phishing"
+                     confidence = 0.85
+                 else:
+                     verdict = "legitimate"
+                     confidence = 0.70
+                     if not reasons:
+                         reasons = ["No suspicious patterns detected"]
+
+                 result = {
+                     "url": url,
+                     "verdict": verdict,
+                     "confidence": float(confidence),
+                     "reasons": set(reasons + ["Note: Using rule-based analysis due to model inference error"])
+                 }
+
+                 self.url_cache[url] = result
+                 results.append(result)
+
+             return results
+
+     async def _batch_predict(self, inputs):
+         ort_inputs = {
+             "input_ids": inputs["input_ids"].astype(np.int64),
+             "attention_mask": inputs["attention_mask"].astype(np.int64)
+         }
+         return self.session.run(None, ort_inputs)[0]
+
+     async def _batch_analyze(self, urls):
+         processed_urls = self._batch_preprocess(urls)
+         inputs = self._batch_tokenize(processed_urls)
+         logits = await self._batch_predict(inputs)
+         probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
+         return probabilities[:, 1]
+
+     async def analyze_batch(self, urls):
+         results = []
+         uncached_urls = []
+
+         # Check cache first
+         for url in urls:
+             if url in self.url_cache:
+                 results.append(self.url_cache[url])
+             else:
+                 uncached_urls.append(url)
+
+         if uncached_urls:
+             # Split URLs into smaller batches for multithreaded processing
+             batch_size = 10  # Process 10 URLs per thread
+             url_batches = [uncached_urls[i:i+batch_size] for i in range(0, len(uncached_urls), batch_size)]
+
+             # Submit each batch to thread pool
+             futures = []
+             for batch in url_batches:
+                 futures.append(self.thread_pool.submit(self._predict_thread, batch))
+
+             # Collect results from all threads
+             for batch, future in zip(url_batches, futures):  # keep each future paired with its batch so the error fallback below reports the right URLs
+                 try:
+                     batch_results = future.result()
+                     results.extend(batch_results)
+                 except Exception as e:
+                     # Handle any unexpected errors in thread execution
+                     print(f"Error processing batch: {str(e)}")
+                     # Create fallback results for this batch
+                     for url in batch:
+                         results.append({
+                             "url": url,
+                             "verdict": "error",
+                             "confidence": 0.0,
+                             "reasons": {f"Error analyzing URL: {str(e)}"}
+                         })
+
+         return results
+
+ app = FastAPI()
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ detector = ONNXPhishingDetector()
+
+ class UrlList(BaseModel):
+     urls: list[str]
+
+ @app.post("/scan")
+ async def scan_urls(url_list: UrlList):
+     start_time = time.time()
+
+     # Process all URLs in a single batch with internal multithreading
+     results = await detector.analyze_batch(url_list.urls)
+
+     phishing_count = sum(1 for r in results if r["verdict"] == "phishing")
+     avg_confidence = sum(r["confidence"] for r in results) / len(results) if results else 0
+
+     if avg_confidence >= 0.99:
+         overall_verdict = "malicious"
+     else:
+         overall_verdict = "safe"
+
+     return {
+         "time_taken": f"{time.time() - start_time:.2f}s",
+         "total_urls": len(url_list.urls),
+         "legitimate": len(url_list.urls) - phishing_count,
+         "phishing": phishing_count,
+         "overall_verdict": overall_verdict,
+         "average_confidence": avg_confidence,
+         "results": results
+     }
+
+ # To run this file in terminal:
+ # 1. Make sure you have all dependencies installed:
+ #    pip install fastapi uvicorn onnxruntime numpy transformers tldextract
+ # 2. Navigate to the directory containing this file
+ # 3. Run the command:
+ #    uvicorn app:app --host 0.0.0.0 --port 8000
+ #    or simply (uses the __main__ block below):
+ #    python app.py
+ # 4. The API will be available at http://localhost:8000
+ # 5. You can test it with curl:
+ #    curl -X POST "http://localhost:8000/scan" -H "Content-Type: application/json" -d '{"urls":["google.com", "suspicious-phishing-site.xyz"]}'
+ # 6. Or use tools like Postman to send POST requests to the /scan endpoint
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
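Usage note (illustrative, not part of this commit): the /scan endpoint above takes a JSON body with a "urls" list and returns per-URL verdicts plus aggregate counts. A minimal Python client sketch, assuming the server was started locally with `python app.py` on port 8000 (the Docker image instead listens on 7860):

import requests

BASE_URL = "http://localhost:8000"  # assumption: local run; use port 7860 when calling the container

payload = {"urls": ["google.com", "suspicious-phishing-site.xyz"]}
resp = requests.post(f"{BASE_URL}/scan", json=payload, timeout=60)
resp.raise_for_status()

report = resp.json()
print(report["overall_verdict"], report["time_taken"])
for item in report["results"]:
    # Each result carries the URL, the verdict, the confidence score, and the reasons collected by the detector.
    print(item["url"], item["verdict"], f"{item['confidence']:.2%}")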
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi
+ uvicorn[standard]
+ onnxruntime
+ numpy
+ transformers
+ tldextract
+ scikit-learn
+ joblib
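Note on the model artifact (illustrative, not part of this commit): app.py loads phishing_detector.onnx, which is not added here. A minimal export sketch, assuming a DeBERTa-v3 checkpoint fine-tuned for two-class sequence classification; the checkpoint path and opset below are assumptions, not taken from this repo:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

checkpoint = "path/to/fine-tuned-deberta-v3-phishing"  # assumption: your own fine-tuned checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.return_dict = False  # export plain tuple outputs (logits first) instead of a ModelOutput
model.eval()

# Dummy batch matching what app.py sends: 128-token padded input_ids / attention_mask.
dummy = tokenizer(["http://example.com"], max_length=128, padding="max_length",
                  truncation=True, return_tensors="pt")

torch.onnx.export(
    model,
    (dummy["input_ids"], dummy["attention_mask"]),
    "phishing_detector.onnx",
    input_names=["input_ids", "attention_mask"],  # must match the ort_inputs keys used in app.py
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=14,
)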