CaxtonEmeraldS's picture
Update app.py
4d2375b verified
raw
history blame
2.37 kB
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import zipfile
import os
from huggingface_hub import hf_hub_download
# Hugging Face repository ID of the Space that hosts the model archive.
repo_id = "CaxtonEmeraldS/CholesterolConcentrationPredictor"  # Replace with your actual repo name
# Download and unzip the models only once: if the extraction directory already
# exists it is reused as-is and nothing is downloaded.
unzip_dir = "unzipped_models"
if not os.path.exists(unzip_dir):
    print("Downloading and extracting model zip file...")
    # hf_hub_download expects repo_id in "namespace/name" form; the kind of
    # repository is selected with repo_type rather than a "spaces/" prefix
    # (a prefixed id is rejected, and os.path.join would insert "\" on Windows).
    zip_path = hf_hub_download(
        repo_id=repo_id,
        filename="Models.zip",  # Replace with your actual uploaded ZIP filename
        repo_type="space",
    )
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(unzip_dir)
    print("Extraction complete.")
# Load linear models.
# NOTE: currently disabled. Uncommenting these lines defines the module-level
# names linear_rgb / linear_grey (assumes the archive contains
# linear_models/*.joblib — confirm before re-enabling).
# linear_rgb_path = os.path.join(unzip_dir, "linear_models/linear_rgb.joblib")
# linear_grey_path = os.path.join(unzip_dir, "linear_models/linear_grey.joblib")
# linear_rgb = joblib.load(linear_rgb_path)
# linear_grey = joblib.load(linear_grey_path)
def _as_path_label(value):
    """Format a UI value for use in a model folder/file name.

    gr.Number returns floats unless precision=0, so an integral float such as
    42.0 is rendered as "42" to match names like ``seed_42``; every other value
    is passed through ``str`` unchanged.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def _resolve_model_path(activation, seed, neurons):
    """Return the path of the requested ANN model inside ``unzip_dir``.

    Raises:
        ValueError: if the selection would resolve outside ``unzip_dir``
            (e.g. an activation such as "../.." or "/etc").
        FileNotFoundError: if no model file exists for the selection.
    """
    # Join components separately so an empty activation cannot turn the
    # remainder into an absolute path ("/seed_...").
    keras_path = os.path.join(
        unzip_dir,
        str(activation),
        f"seed_{_as_path_label(seed)}",
        f"model_{_as_path_label(neurons)}.keras",
    )
    # The activation comes from a free-text box, so confine the lookup to the
    # extracted model directory before touching the filesystem.
    base = os.path.realpath(unzip_dir)
    if os.path.commonpath([base, os.path.realpath(keras_path)]) != base:
        raise ValueError(f"Invalid model selection: {keras_path}")
    if not os.path.exists(keras_path):
        raise FileNotFoundError(f"Model not found: {keras_path}")
    return keras_path


def _linear_predictions(X):
    """Return ``(rgb, grey)`` predictions from the optional linear models.

    The module-level loading of ``linear_rgb`` / ``linear_grey`` may be
    disabled (its joblib code is commented out), so the names are looked up
    dynamically and a placeholder is returned for both outputs while either
    model is unavailable.
    """
    linear_rgb = globals().get("linear_rgb")
    linear_grey = globals().get("linear_grey")
    if linear_rgb is None or linear_grey is None:
        unavailable = "N/A (linear models not loaded)"
        return unavailable, unavailable
    r, g, b = X[0]
    # Weighted greyscale intensity, same weights as the original (commented-out) code.
    grey = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return linear_rgb.predict(X)[0], linear_grey.predict([[grey]])[0]


def predict(r, g, b, activation, seed, neurons):
    """Predict with the selected ANN model (and the linear models when loaded).

    Args:
        r, g, b: colour channel values, fed to the models as a single sample.
        activation: activation folder name inside ``unzip_dir``.
        seed: seed folder number (``seed_<seed>``).
        neurons: model number (``model_<neurons>.keras``).

    Returns:
        ``(ann_prediction, linear_rgb_prediction, linear_grey_prediction)``.
        On any failure returns ``("Error: <message>", "", "")`` so the UI
        shows the problem instead of crashing.
    """
    try:
        X = np.array([[r, g, b]], dtype=float)
        keras_path = _resolve_model_path(activation, seed, neurons)
        # TODO(review): the model is reloaded from disk on every request;
        # consider caching loaded models if latency matters.
        model = tf.keras.models.load_model(keras_path)
        # verbose=0 suppresses Keras' per-call progress bar in the server log.
        ann_pred = model.predict(X, verbose=0)[0][0]
        lin_pred_rgb, lin_pred_grey = _linear_predictions(X)
        return ann_pred, lin_pred_rgb, lin_pred_grey
    except Exception as e:  # UI boundary: surface every failure in the output box
        return f"Error: {str(e)}", "", ""
# Gradio UI: three colour inputs plus the model selection (activation folder,
# seed folder, neuron count), and one text output per model.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Number(label="R"),
        gr.Number(label="G"),
        gr.Number(label="B"),
        gr.Textbox(label="Activation (folder name)"),
        # precision=0 makes Gradio round and return ints, so predict builds
        # "seed_42" / "model_64.keras" rather than "seed_42.0" / "model_64.0.keras".
        gr.Number(label="Seed (folder name)", precision=0),
        gr.Number(label="Neurons (model number)", precision=0),
    ],
    outputs=[
        gr.Text(label="ANN Model Prediction"),
        gr.Text(label="Linear RGB Prediction"),
        gr.Text(label="Linear Grey Prediction"),
    ],
    title="ANN vs Linear Model Predictor",
    description="Dynamically load models from Hugging Face repo and predict."
)

if __name__ == "__main__":
    iface.launch()