tumor_model / api.py
agcaabdurrahim's picture
Upload folder using huggingface_hub
647ebc1 verified
import io
import logging
import mimetypes
import os
import random
import tempfile

from fastapi import BackgroundTasks, FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
import matplotlib
matplotlib.use('Agg') # Set the backend to Agg before importing pyplot
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model
app = FastAPI()

# Browser origins allowed to call this API, given as a comma-separated list in
# the ALLOWED_ORIGINS environment variable (e.g. "https://app.example.com").
# Unset or empty keeps the previous behaviour of allowing any origin ("*").
# TODO: set ALLOWED_ORIGINS to the frontend domain(s) in production.
# NOTE(review): a "*" origin combined with allow_credentials=True is broader
# than a credentialed CORS policy should be; restrict origins before deploying.
_allowed_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Prefer the model stored next to this file (the same way get_random_image
# locates its Testing/ directory) so startup does not depend on the working
# directory; fall back to the original CWD-relative path.
_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'my_model.keras')
if not os.path.exists(_model_path):
    _model_path = 'my_model.keras'
model = load_model(_model_path)
def predict_and_plot(img):
    """Classify ``img`` and render it above a class-probability bar chart.

    Args:
        img: A PIL image of any mode. It is converted to 3-channel RGB and
            resized to 299x299 before being passed to ``model``.

    Returns:
        ``(path, prediction_text)``. ``path`` is a PNG file created with
        ``delete=False``, so the caller is responsible for removing it.
        ``prediction_text`` reads ``"Prediction: <label> (<prob>%)"`` for the
        highest-probability class.
    """
    # Position of each class in the model's output vector (assumed to match
    # the label encoding used at training time -- confirm if retrained).
    class_dict = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
    label = list(class_dict.keys())

    # convert('RGB') handles every PIL mode. The old numpy-shape checks only
    # covered L/RGB/RGBA: palette ('P') images were fed palette indices, 'LA'
    # produced 2 channels, 'CMYK' had C/M/Y read as R/G/B. An RGB image also
    # makes imshow below draw grayscale scans in gray rather than through the
    # default colormap applied to 2-D arrays.
    resized_img = img.convert('RGB').resize((299, 299))
    img_array = np.expand_dims(np.asarray(resized_img), axis=0) / 255.0
    # verbose=0: no Keras progress bar in the server log on every request.
    predictions = model.predict(img_array, verbose=0)
    probs = list(predictions[0])
    max_prob_idx = np.argmax(probs)
    prediction_text = f"Prediction: {label[max_prob_idx]} ({probs[max_prob_idx]:.2%})"

    # Create the figure only once prediction has succeeded, and always close
    # it, so a failing request cannot leave a figure registered with pyplot.
    fig, (ax_img, ax_bar) = plt.subplots(2, 1, figsize=(16, 12))
    try:
        ax_img.imshow(resized_img)
        ax_img.set_title('Input Image', fontsize=16, pad=20)
        ax_img.axis('off')

        bars = ax_bar.barh(label, probs)
        ax_bar.set_xlabel('Probability', fontsize=14)
        ax_bar.set_title('Prediction Probabilities', fontsize=16, pad=20)
        ax_bar.bar_label(bars, fmt='%.2f', fontsize=12)
        ax_bar.set_xlim(0, 1.1)  # Leave room for the bar labels
        fig.tight_layout()  # Prevent label cutoff

        # The context manager closes the OS handle (previously leaked);
        # delete=False keeps the file on disk for the caller to serve.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
            fig.savefig(temp_file, format='png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    return temp_file.name, prediction_text
def get_random_image(testing_dir=None):
    """Pick a random sample image from the ``Testing`` dataset directory.

    A class subdirectory is chosen uniformly first, then an image within it,
    so each class is equally likely regardless of how many images it holds.

    Args:
        testing_dir: Directory whose immediate subdirectories hold the images.
            Defaults to ``Testing`` next to this file.

    Returns:
        Path (joined onto ``testing_dir``) of the chosen ``.jpg``/``.jpeg``/
        ``.png`` file; the extension match is case-insensitive.

    Raises:
        FileNotFoundError: ``testing_dir`` does not exist, contains no
            subdirectories, or the chosen subdirectory contains no images.
    """
    logger = logging.getLogger(__name__)

    # No random.seed() here: Python's `random` is already seeded when the
    # module is imported, and re-seeding on every call discarded any seed the
    # application set for reproducibility.
    if testing_dir is None:
        base_dir = os.path.abspath(os.path.dirname(__file__))
        testing_dir = os.path.join(base_dir, 'Testing')
    logger.debug("Testing directory: %s", testing_dir)

    # Sorted so a seeded `random` yields the same choice regardless of the
    # arbitrary order os.listdir returns entries in.
    subdirs = sorted(
        d for d in os.listdir(testing_dir)
        if os.path.isdir(os.path.join(testing_dir, d))
    )
    if not subdirs:
        raise FileNotFoundError(f"No subdirectories found in {testing_dir}")
    random_subdir = random.choice(subdirs)
    logger.debug("Selected subdirectory %s from %s", random_subdir, subdirs)

    subdir_path = os.path.join(testing_dir, random_subdir)
    images = sorted(
        f for f in os.listdir(subdir_path)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    if not images:
        raise FileNotFoundError(f"No images found in {random_subdir}")
    random_image = random.choice(images)

    full_path = os.path.join(subdir_path, random_image)
    logger.debug("Selected image: %s", full_path)
    return full_path
@app.get("/get-random-image")
async def get_random_image_endpoint():
    """Serve a random image from the Testing dataset with caching disabled.

    Returns:
        A ``FileResponse`` whose ``Content-Type`` is guessed from the file
        extension. ``get_random_image`` returns ``.jpg``/``.jpeg`` as well as
        ``.png`` files, so the previous fixed ``image/png`` mislabelled JPEG
        payloads. ``image/png`` remains the fallback when no type is guessed.

    Raises:
        Exception: Lookup failures are logged and re-raised unchanged.
    """
    try:
        random_image_path = get_random_image()
        if not os.path.exists(random_image_path):
            raise Exception(f"Image file not found: {random_image_path}")
    except Exception as e:
        logging.getLogger(__name__).exception("Error getting random image: %s", e)
        raise
    media_type = mimetypes.guess_type(random_image_path)[0] or "image/png"
    return FileResponse(
        random_image_path,
        media_type=media_type,
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, OPTIONS",
            "Access-Control-Allow-Headers": "*",
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Pragma": "no-cache",
            "Expires": "0",
        },
    )
@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the rendered prediction plot.

    ``predict_and_plot`` writes its PNG with ``delete=False``. The file is
    removed by a background task once the response has been sent, so the temp
    directory no longer gains one file per request.

    Raises:
        Exception: Failures while reading, decoding or predicting are logged
            and re-raised unchanged.
    """
    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents))
        # Only the rendered plot is returned; the text form is served by
        # the /predict-text endpoint.
        result_path, _prediction_text = predict_and_plot(img)
    except Exception as e:
        logging.getLogger(__name__).exception("Error processing image: %s", e)
        raise
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, result_path)
    return FileResponse(
        result_path,
        media_type="image/png",
        headers={
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "*",
        },
        background=cleanup,
    )
@app.post("/predict-text")
async def predict_text(file: UploadFile = File(...)):
    """Classify an uploaded image and return only the prediction string.

    Returns:
        ``{"prediction": "Prediction: <label> (<prob>%)"}`` for the
        highest-probability class.

    Raises:
        Exception: Failures while reading, decoding or predicting are logged
            and re-raised unchanged.
    """
    # Position of each class in the model's output vector (assumed to match
    # the label encoding used at training time -- confirm if retrained).
    class_dict = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
    label = list(class_dict.keys())
    try:
        contents = await file.read()
        img = Image.open(io.BytesIO(contents))
        # convert('RGB') handles every PIL mode. The old numpy-shape checks
        # only covered L/RGB/RGBA: palette ('P') images were fed palette
        # indices, 'LA' produced 2 channels, 'CMYK' had C/M/Y read as R/G/B.
        resized_img = img.convert('RGB').resize((299, 299))
        img_array = np.expand_dims(np.asarray(resized_img), axis=0) / 255.0
        # verbose=0: no Keras progress bar in the server log on every request.
        predictions = model.predict(img_array, verbose=0)
        probs = list(predictions[0])
        max_prob_idx = np.argmax(probs)
        prediction_text = f"Prediction: {label[max_prob_idx]} ({probs[max_prob_idx]:.2%})"
        return {"prediction": prediction_text}
    except Exception as e:
        logging.getLogger(__name__).exception("Error processing image: %s", e)
        raise
if __name__ == "__main__":
    # Development entry point, used only when this file is executed as a
    # script. The import is local so uvicorn is only needed in that case.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")