from flask import Flask, request, jsonify
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import gc

app = Flask(__name__)

model = None
tokenizer = None
device = None


def setup_device():
    """Pick the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')


def load_tokenizer():
    """Load the fine-tuned tokenizer, falling back to the base CodeBERT tokenizer."""
    try:
        tokenizer = RobertaTokenizer.from_pretrained('./tokenizer_vulnerability')
        tokenizer.model_max_length = 512
        return tokenizer
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        try:
            return RobertaTokenizer.from_pretrained('microsoft/codebert-base')
        except Exception as e2:
            print(f"Fallback tokenizer failed: {e2}")
            return None


def load_model():
    """Load the fine-tuned checkpoint onto the selected device."""
    global device
    device = setup_device()
    print(f"Using device: {device}")
    try:
        checkpoint = torch.load("codebert_vulnerability_scorer.pth", map_location=device)
        # The checkpoint may carry its own config, a full training checkpoint,
        # or a bare state dict; handle all three layouts.
        if 'config' in checkpoint:
            from transformers import RobertaConfig
            config = RobertaConfig.from_dict(checkpoint['config'])
            model = RobertaForSequenceClassification(config)
        else:
            model = RobertaForSequenceClassification.from_pretrained(
                'microsoft/codebert-base',
                num_labels=1
            )
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        # Half precision on CUDA cuts memory use and speeds up inference.
        if device.type == 'cuda':
            model.half()
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


def cleanup_gpu_memory():
    """Release cached CUDA memory between requests."""
    if device and device.type == 'cuda':
        torch.cuda.empty_cache()
    gc.collect()


# Load the tokenizer and model once at startup so requests don't pay the cost.
try:
    print("Loading tokenizer...")
    tokenizer = load_tokenizer()
    if tokenizer:
        print("Tokenizer loaded successfully!")
    else:
        print("Failed to load tokenizer!")

    print("Loading model...")
    model = load_model()
    if model:
        print("Model loaded successfully!")
    else:
        print("Failed to load model!")
except Exception as e:
    print(f"Error during initialization: {str(e)}")
    tokenizer = None
    model = None


@app.route("/", methods=['GET'])
def home():
    return jsonify({
        "message": "CodeBERT Vulnerability Evaluator API",
        "status": "Model loaded" if model is not None else "Model not loaded",
        "device": str(device) if device else "unknown",
        "endpoints": {
            "/predict": "POST with JSON body containing 'codes' array"
        }
    })


@app.route("/predict", methods=['POST'])
def predict_batch():
    try:
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded properly"}), 500

        data = request.get_json()
        if not data or 'codes' not in data:
            return jsonify({"error": "Missing 'codes' field in JSON body"}), 400

        codes = data['codes']
        if not isinstance(codes, list) or len(codes) == 0:
            return jsonify({"error": "'codes' must be a non-empty array"}), 400
        if len(codes) > 100:
            return jsonify({"error": "Too many codes. Maximum 100 allowed."}), 400

        # Validate each entry; substitute a placeholder for empty snippets.
        validated_codes = []
        for i, code in enumerate(codes):
            if not isinstance(code, str):
                return jsonify({"error": f"Code at index {i} must be a string"}), 400
            if len(code.strip()) == 0:
                validated_codes.append("# empty code")
            elif len(code) > 50000:
                return jsonify({"error": f"Code at index {i} too long. Maximum 50000 characters."}), 400
            else:
                validated_codes.append(code.strip())

        # Single snippet: skip batching entirely.
        if len(validated_codes) == 1:
            score = predict_vulnerability_with_chunking(validated_codes[0])
            cleanup_gpu_memory()
            return jsonify({"results": [{"score": round(1.0 - score, 4)}]})

        batch_size = min(len(validated_codes), 16)
        results = []
        try:
            for i in range(0, len(validated_codes), batch_size):
                batch = validated_codes[i:i + batch_size]

                # Route snippets by token length: long ones go through the
                # chunking path, short ones through one batched forward pass.
                long_codes = []
                short_codes = []
                long_indices = []
                short_indices = []
                for idx, code in enumerate(batch):
                    try:
                        tokens = tokenizer.encode(code, add_special_tokens=False,
                                                  max_length=1000, truncation=True)
                        if len(tokens) > 450:
                            long_codes.append(code)
                            long_indices.append(i + idx)
                        else:
                            short_codes.append(code)
                            short_indices.append(i + idx)
                    except Exception as e:
                        print(f"Tokenization error for code {i + idx}: {e}")
                        short_codes.append(code)
                        short_indices.append(i + idx)

                batch_scores = [0.0] * len(batch)

                if short_codes:
                    try:
                        short_scores = predict_vulnerability_batch(short_codes)
                        for j, score in enumerate(short_scores):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Batch prediction error: {e}")
                        for j in range(len(short_codes)):
                            local_idx = short_indices[j] - i
                            batch_scores[local_idx] = 0.0

                for j, code in enumerate(long_codes):
                    try:
                        score = predict_vulnerability_with_chunking(code)
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = score
                    except Exception as e:
                        print(f"Chunking prediction error: {e}")
                        local_idx = long_indices[j] - i
                        batch_scores[local_idx] = 0.0

                # The API reports safety (1 - vulnerability score), rounded to 4 places.
                for score in batch_scores:
                    results.append({"score": round(1.0 - score, 4)})

                cleanup_gpu_memory()
        except Exception:
            cleanup_gpu_memory()
            raise

        return jsonify({"results": results})
    except Exception as e:
        cleanup_gpu_memory()
        return jsonify({"error": f"Batch prediction error: {str(e)}"}), 500


def predict_vulnerability_with_chunking(code):
    """Score a long snippet by sliding a token window and keeping the max score."""
    try:
        if not code or len(code.strip()) == 0:
            return 0.0

        tokens = tokenizer.encode(code, add_special_tokens=False, max_length=2000, truncation=True)
        if len(tokens) <= 450:
            return predict_vulnerability(code)

        # 400-token chunks with a 50-token overlap so findings that straddle a
        # chunk boundary are not missed.
        chunk_size = 400
        overlap = 50
        max_score = 0.0
        for start in range(0, len(tokens), chunk_size - overlap):
            end = min(start + chunk_size, len(tokens))
            chunk_tokens = tokens[start:end]
            try:
                chunk_code = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
                if chunk_code.strip():
                    score = predict_vulnerability(chunk_code)
                    max_score = max(max_score, score)
            except Exception as e:
                print(f"Chunk processing error: {e}")
                continue
            if end >= len(tokens):
                break
        return max_score
    except Exception as e:
        print(f"Chunking error: {e}")
        return 0.0


def predict_vulnerability(code):
    """Run one snippet through the model and return a vulnerability score in [0, 1]."""
    try:
        if not code or len(code.strip()) == 0:
            return 0.0

        # Pad to a length proportional to the snippet size, capped at the model limit.
        dynamic_length = min(max(len(code.split()) * 2, 128), 512)
        inputs = tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Scale logits before the sigmoid to push scores toward 0 or 1.
        amplified_logits = 2.0 * outputs.logits
        score = torch.sigmoid(amplified_logits).cpu().item()
        return round(max(0.0, min(1.0, score)), 4)
    except Exception as e:
        print(f"Single prediction error: {e}")
        return 0.0


def predict_vulnerability_batch(codes):
    """Score a list of short snippets in a single forward pass."""
    try:
        if not codes or len(codes) == 0:
            return []

        filtered_codes = [code if code and code.strip() else "# empty" for code in codes]
        # Pad the whole batch to one length derived from its longest snippet.
        max_len = max(len(code.split()) * 2 for code in filtered_codes)
        dynamic_length = min(max(max_len, 128), 512)
        inputs = tokenizer(
            filtered_codes,
            truncation=True,
            padding='max_length',
            max_length=dynamic_length,
            return_tensors='pt'
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                with torch.cuda.amp.autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        # Same logit amplification as the single-snippet path.
        amplified_logits = 2.0 * outputs.logits
        scores = torch.sigmoid(amplified_logits).cpu().numpy()
        return [round(max(0.0, min(1.0, float(score))), 4) for score in scores.flatten()]
    except Exception as e:
        print(f"Batch prediction error: {e}")
        return [0.0] * len(codes)


@app.route("/health", methods=['GET'])
def health_check():
    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "tokenizer_loaded": tokenizer is not None,
        "device": str(device) if device else "unknown"
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
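
# A minimal client sketch for exercising the /predict endpoint, kept as a
# comment so the module itself is unchanged. It assumes the server above is
# running locally on port 7860 and that the `requests` package is installed;
# the snippet is illustrative, but the field names mirror the JSON contract
# defined in predict_batch().
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:7860/predict",
#         json={"codes": ["import pickle\npickle.loads(user_input)"]},
#         timeout=60,
#     )
#     print(resp.json())  # e.g. {"results": [{"score": 0.1234}]}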