"""FastAPI service that classifies brain-MRI images with a Keras model."""

from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import io
import logging
import matplotlib
matplotlib.use('Agg')  # Set the backend to Agg before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, UnidentifiedImageError
from tensorflow.keras.models import load_model
import tempfile
import random
import os

logger = logging.getLogger(__name__)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    # TODO: change to allow only the frontend domain.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

model = load_model('my_model.keras')

# Class name -> index into the model's output vector.
CLASS_DICT = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
# Class names ordered by their output index, so LABELS[i] names probs[i].
LABELS = sorted(CLASS_DICT, key=CLASS_DICT.__getitem__)

# Size images are resized to before prediction, as (width, height) for PIL.
INPUT_SIZE = (299, 299)

# File extensions (lower-case) that count as images in the Testing directory.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Private OS-entropy RNG: picking random test images neither depends on nor
# disturbs the global `random` state other code may have seeded.
_rng = random.SystemRandom()


def _prepare_image(img):
    """Return ``(resized_rgb, batch)`` for a PIL image.

    ``resized_rgb`` is *img* converted to RGB and resized to INPUT_SIZE (used
    for display); ``batch`` is the same pixels as a float32 array of shape
    (1, height, width, 3) scaled to [0, 1] (used as model input).
    """
    # convert('RGB') replicates single-channel grayscale into three channels
    # and discards alpha, and also maps palette, 'LA' and CMYK images to real
    # RGB values instead of feeding raw indices / wrong channels to the model.
    resized_rgb = img.convert('RGB').resize(INPUT_SIZE)
    batch = np.asarray(resized_rgb, dtype=np.float32)[np.newaxis] / 255.0
    return resized_rgb, batch


def _predict(batch):
    """Run the model on *batch* and return ``(probs, prediction_text)``.

    ``probs`` is the model output for the first (only) sample;
    ``prediction_text`` names the highest-scoring class and its probability.
    """
    probs = model.predict(batch, verbose=0)[0]  # verbose=0: no progress bar per request
    best = int(np.argmax(probs))
    prediction_text = f"Prediction: {LABELS[best]} ({probs[best]:.2%})"
    return probs, prediction_text


def _decode_upload(contents):
    """Decode uploaded bytes into a fully loaded PIL image.

    Raises:
        HTTPException: status 400 if the bytes are not a readable image.
    """
    try:
        img = Image.open(io.BytesIO(contents))
        img.load()  # force decoding now so truncated files fail here
    except (UnidentifiedImageError, OSError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err
    return img


def predict_and_plot(img):
    """Classify *img* and render the input plus class probabilities to a PNG.

    Args:
        img: A PIL image in any mode Pillow can convert to RGB.

    Returns:
        ``(path, prediction_text)``. *path* is a newly created temporary PNG
        file; the caller owns it and is responsible for deleting it.
    """
    resized_img, batch = _prepare_image(img)
    probs, prediction_text = _predict(batch)

    fig, (ax_img, ax_bar) = plt.subplots(2, 1, figsize=(16, 12))
    try:
        ax_img.imshow(resized_img)
        ax_img.set_title('Input Image', fontsize=16, pad=20)
        ax_img.axis('off')

        bars = ax_bar.barh(LABELS, probs)
        ax_bar.set_xlabel('Probability', fontsize=14)
        ax_bar.set_title('Prediction Probabilities', fontsize=16, pad=20)
        ax_bar.bar_label(bars, fmt='%.2f', fontsize=12)
        ax_bar.set_xlim(0, 1.1)  # leave room to the right for the bar labels
        fig.tight_layout()  # prevent titles/labels from being cut off

        fd, path = tempfile.mkstemp(suffix='.png')
        try:
            # Write through the descriptor and close it, so no handle leaks.
            with os.fdopen(fd, 'wb') as fh:
                fig.savefig(fh, format='png', dpi=300, bbox_inches='tight')
        except Exception:
            os.remove(path)  # don't leave a half-written file behind
            raise
    finally:
        plt.close(fig)  # release the figure even if plotting/saving failed

    return path, prediction_text


def get_random_image():
    """Return the absolute path of a randomly chosen test image.

    Picks a random subdirectory of the ``Testing`` directory located next to
    this file, then a random ``.jpg``/``.jpeg``/``.png`` file inside it.

    Raises:
        FileNotFoundError: if ``Testing`` does not exist, has no
            subdirectories, or the chosen subdirectory contains no images.
    """
    base_dir = os.path.abspath(os.path.dirname(__file__))
    testing_dir = os.path.join(base_dir, 'Testing')
    logger.debug("Testing directory: %s", testing_dir)

    subdirs = [
        d for d in os.listdir(testing_dir)
        if os.path.isdir(os.path.join(testing_dir, d))
    ]
    logger.debug("Available subdirectories: %s", subdirs)
    if not subdirs:
        raise FileNotFoundError(f"No subdirectories found in {testing_dir}")

    random_subdir = _rng.choice(subdirs)
    subdir_path = os.path.join(testing_dir, random_subdir)
    images = [
        f for f in os.listdir(subdir_path)
        if f.lower().endswith(IMAGE_EXTENSIONS)
    ]
    logger.debug("Found %d images in %s", len(images), random_subdir)
    if not images:
        raise FileNotFoundError(f"No images found in {random_subdir}")

    full_path = os.path.join(subdir_path, _rng.choice(images))
    logger.debug("Selected image: %s", full_path)
    return full_path


# NOTE(review): model.predict and the matplotlib rendering below run
# synchronously inside `async def` endpoints, so they block the event loop
# while they run. Consider offloading them to a thread pool if concurrent
# throughput matters (verify Keras/pyplot thread-safety first).
#
# CORS headers are intentionally not set on individual responses: they are
# left to CORSMiddleware so that restricting allow_origins (see TODO above)
# actually takes effect.


@app.get("/get-random-image")
async def get_random_image_endpoint():
    """Serve a random image from the Testing directory, with caching disabled."""
    try:
        random_image_path = get_random_image()
        if not os.path.exists(random_image_path):
            raise FileNotFoundError(f"Image file not found: {random_image_path}")
    except Exception:
        logger.exception("Error getting random image")
        raise

    # media_type is left unset so FileResponse infers it from the extension;
    # get_random_image may return .jpg/.jpeg as well as .png files.
    return FileResponse(
        random_image_path,
        headers={
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
        },
    )


@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return a PNG visualising the result.

    Responds with HTTP 400 if the upload is not a readable image.
    """
    img = _decode_upload(await file.read())
    try:
        result_path, _prediction_text = predict_and_plot(img)
    except Exception:
        logger.exception("Error processing image")
        raise

    # Delete the temporary PNG once the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, result_path)
    return FileResponse(result_path, media_type="image/png", background=cleanup)


@app.post("/predict-text")
async def predict_text(file: UploadFile = File(...)):
    """Classify an uploaded image and return ``{"prediction": <text>}``.

    Responds with HTTP 400 if the upload is not a readable image.
    """
    img = _decode_upload(await file.read())
    try:
        _, batch = _prepare_image(img)
        _, prediction_text = _predict(batch)
    except Exception:
        logger.exception("Error processing image")
        raise
    return {"prediction": prediction_text}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)