|
from fastapi import FastAPI, File, UploadFile |
|
from fastapi.responses import FileResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import io |
|
import matplotlib |
|
matplotlib.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
from tensorflow.keras.models import load_model |
|
import tempfile |
|
import random |
|
import os |
|
# HTTP application object; the route decorators in this module register on it.
app = FastAPI()

# Permissive CORS policy: every origin, method and request header is accepted,
# and every response header is exposed to browser clients.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# combination the FastAPI CORS documentation advises against - confirm that
# credentialed cross-origin requests are really required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# The classifier is loaded once, at import time, and shared by all requests.
# The path is relative, so it is resolved against the process's working
# directory rather than this file's location.
model = load_model('my_model.keras')
|
|
|
def predict_and_plot(img):
    """Classify an image and render the result as a PNG report.

    The image is resized to 299x299, converted to a three-channel array scaled
    by 1/255 and passed to the module-level ``model``.  The report shows the
    resized input on top and a horizontal bar chart of the per-class
    probabilities underneath.

    Args:
        img: A PIL image (the object must provide ``resize`` and ``convert``).

    Returns:
        A ``(path, prediction_text)`` tuple: the path of the PNG report, which
        is written to a temporary file, and a string of the form
        ``"Prediction: <label> (<probability as a percentage>)"``.  The
        temporary file is NOT deleted here; the caller owns its cleanup.

    Raises:
        Exception: Whatever the model or matplotlib raises is propagated; in
            that case the figure is closed and the temporary file is removed.
    """
    # The code treats index i of the model output as belonging to label[i].
    class_dict = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
    label = list(class_dict.keys())

    resized_img = img.resize((299, 299))

    # convert('RGB') yields the same array as the previous manual handling for
    # grayscale (channel replicated) and RGBA (alpha dropped) images, and also
    # copes with palette, LA, CMYK and 1-bit images, which used to reach the
    # model with the wrong channel count or with palette indices as pixels.
    rgb_array = np.asarray(resized_img.convert('RGB'))
    img_array = np.expand_dims(rgb_array, axis=0) / 255.0

    predictions = model.predict(img_array)
    probs = list(predictions[0])

    max_prob_idx = np.argmax(probs)
    prediction_text = f"Prediction: {label[max_prob_idx]} ({probs[max_prob_idx]:.2%})"

    # Reserve the output path and close the handle straight away: only the
    # name is needed, and an open handle would leak a file descriptor (and
    # can block the write on platforms that lock open files).
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
        output_path = temp_file.name

    # The figure is created only after prediction succeeded, and is always
    # closed, so a failing request cannot leave a pyplot figure behind.
    fig = plt.figure(figsize=(16, 12))
    try:
        plt.subplot(2, 1, 1)
        plt.imshow(resized_img)
        plt.title('Input Image', fontsize=16, pad=20)
        plt.axis('off')

        plt.subplot(2, 1, 2)
        bars = plt.barh(label, probs)
        plt.xlabel('Probability', fontsize=14)
        plt.title('Prediction Probabilities', fontsize=16, pad=20)
        ax = plt.gca()
        ax.bar_label(bars, fmt='%.2f', fontsize=12)
        # Leave room to the right of a full-length bar for its value label.
        plt.xlim(0, 1.1)

        plt.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    except Exception:
        # Do not leave an empty or partial report behind on failure.
        os.remove(output_path)
        raise
    finally:
        plt.close(fig)

    return output_path, prediction_text
|
|
|
|
|
def get_random_image():
    """Pick a random image file from the ``Testing`` directory next to this file.

    A subdirectory of ``Testing`` is chosen uniformly at random, then a file
    whose name ends in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) is
    chosen uniformly at random inside it.  Each step is reported on stdout.

    Returns:
        The absolute path of the selected image.

    Raises:
        FileNotFoundError: If ``Testing`` does not exist, contains no
            subdirectories, or the selected subdirectory contains no images.
    """
    # The module-level generator is used as-is.  It is seeded automatically,
    # and reseeding it on every call would clobber state set by other code.
    base_dir = os.path.abspath(os.path.dirname(__file__))
    testing_dir = os.path.join(base_dir, 'Testing')
    print(f"Testing directory: {testing_dir}")

    subdirs = [
        entry
        for entry in os.listdir(testing_dir)
        if os.path.isdir(os.path.join(testing_dir, entry))
    ]
    print(f"Available subdirectories: {subdirs}")

    # Fail with a clear message instead of random.choice's bare IndexError.
    if not subdirs:
        raise FileNotFoundError(f"No subdirectories found in {testing_dir}")

    random_subdir = random.choice(subdirs)
    print(f"Selected subdirectory: {random_subdir}")

    subdir_path = os.path.join(testing_dir, random_subdir)
    images = [
        name
        for name in os.listdir(subdir_path)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    print(f"Found {len(images)} images in {random_subdir}")
    print(f"First few images: {images[:5]}")

    if not images:
        raise FileNotFoundError(f"No images found in {random_subdir}")

    random_image = random.choice(images)
    print(f"Selected image: {random_image}")

    full_path = os.path.join(subdir_path, random_image)
    print(f"Full path: {full_path}")
    return full_path
|
|
|
@app.get("/get-random-image")
async def get_random_image_endpoint():
    """Serve a randomly selected image from the test set.

    The file is chosen by ``get_random_image`` and returned as a file
    response whose content type is derived from the file name.  The response
    carries explicit CORS headers and headers that disable caching.

    Returns:
        A ``FileResponse`` streaming the selected image.

    Raises:
        Exception: Any failure while selecting or locating the image is
            printed to stdout and re-raised unchanged.
    """
    # Function-scope import: only this endpoint needs it.
    import mimetypes

    try:
        random_image_path = get_random_image()
        if not os.path.exists(random_image_path):
            raise Exception(f"Image file not found: {random_image_path}")

        # The candidate files may be JPEGs as well as PNGs, so the content
        # type is derived from the file name instead of always claiming PNG.
        media_type = (
            mimetypes.guess_type(random_image_path)[0]
            or "application/octet-stream"
        )

        return FileResponse(
            random_image_path,
            media_type=media_type,
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "GET, OPTIONS",
                "Access-Control-Allow-Headers": "*",
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        )
    except Exception as e:
        print(f"Error getting random image: {str(e)}")
        raise
|
|
|
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the rendered PNG report.

    The upload is decoded with PIL and handed to ``predict_and_plot``, which
    writes the report to a temporary file.  That file is streamed back to the
    client and deleted once the response has been sent.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``FileResponse`` with media type ``image/png`` and explicit CORS
        headers.

    Raises:
        Exception: Any failure while reading, decoding or classifying the
            upload is printed to stdout and re-raised unchanged.
    """
    # Function-scope import: only this endpoint needs it.
    from fastapi import BackgroundTasks

    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents))

        # Only the report path is needed here; the text is served by the
        # separate /predict-text endpoint.
        result_path, _ = predict_and_plot(img)

        # predict_and_plot leaves the temporary report on disk.  Remove it
        # after the response is sent so the temp directory does not grow by
        # one file per request.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, result_path)

        return FileResponse(
            result_path,
            media_type="image/png",
            headers={
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "POST, OPTIONS",
                "Access-Control-Allow-Headers": "*",
            },
            background=cleanup,
        )
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise
|
|
|
@app.post("/predict-text")
async def predict_text(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as text.

    The upload is decoded with PIL, resized to 299x299, converted to a
    three-channel array scaled by 1/255 and passed to the module-level
    ``model``.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict ``{"prediction": "Prediction: <label> (<percentage>)"}`` naming
        the most probable class.

    Raises:
        Exception: Any failure while reading, decoding or classifying the
            upload is printed to stdout and re-raised unchanged.
    """
    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents))

        resized_img = img.resize((299, 299))

        # convert('RGB') yields the same array as the previous manual
        # handling for grayscale (channel replicated) and RGBA (alpha
        # dropped) images, and also copes with palette, LA, CMYK and 1-bit
        # uploads, which used to reach the model with the wrong channel
        # count or with palette indices as pixel values.
        rgb_array = np.asarray(resized_img.convert('RGB'))
        img_array = np.expand_dims(rgb_array, axis=0) / 255.0

        predictions = model.predict(img_array)
        probs = list(predictions[0])

        # The code treats index i of the model output as belonging to label[i].
        class_dict = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
        label = list(class_dict.keys())

        max_prob_idx = np.argmax(probs)
        prediction_text = f"Prediction: {label[max_prob_idx]} ({probs[max_prob_idx]:.2%})"

        return {"prediction": prediction_text}
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        raise
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app on port 8000 of every network
    # interface.  uvicorn is imported here so that importing this module
    # (rather than running it) does not require it.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")