# testingtesting / app.py
# Author: gabo152210 — commit 42f2be9 ("Upload 3 files")
import gradio as gr
import tensorflow as tf
import numpy as np
from pathlib import Path
from PIL import Image
# Length of the noise vector fed to the generator (second dim of z in generate()).
LATENT_DIM = 100

# Saved Keras generator; expected in the same directory as this script.
MODEL_FILE = Path(__file__).parent / "generator_full.keras"

# Cached generator model. It stays None until first use, loaded lazily so the
# app starts quickly.
_gen = None
def get_generator():
    """Return the Keras generator model, loading it from MODEL_FILE on first call.

    The loaded model is cached in the module-level ``_gen`` so later calls
    reuse it instead of reading the file again. ``compile=False`` skips
    compiling the model after loading; it is only used for inference here.
    """
    global _gen
    if _gen is not None:
        return _gen
    _gen = tf.keras.models.load_model(MODEL_FILE, compile=False)
    return _gen
def generate(digit: int, *, n_samples: int = 5):
    """Sample ``n_samples`` images of ``digit`` from the conditional generator.

    Args:
        digit: Class label to condition on; must be a whole number in 0..9.
            The Gradio number input may deliver it as a float, or as None
            when the field is cleared.
        n_samples: How many images to draw (default 5, one gallery row).

    Returns:
        A list of ``n_samples`` grayscale ``PIL.Image.Image`` objects.

    Raises:
        gr.Error: If ``digit`` is missing, not a whole number, or outside 0..9
            (shown to the user as an error message in the UI).
        ValueError: If ``n_samples`` is negative.
    """
    if digit is None:
        raise gr.Error("Please enter a digit between 0 and 9.")
    label = int(digit)
    # `label != digit` catches non-integral floats such as 4.5.
    if label != digit or not 0 <= label <= 9:
        raise gr.Error("Digit must be a whole number between 0 and 9.")
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    z = tf.random.normal([n_samples, LATENT_DIM])
    # One label per sample, as an (n_samples, 1) int32 column.
    labels = tf.constant([[label]] * n_samples, dtype=tf.int32)

    # Map the generator output from [-1, 1] to [0, 1]
    # (presumably a tanh output layer — TODO confirm).
    batch = (get_generator()([z, labels], training=False) + 1) / 2

    # Clip before quantizing. Float values outside [0, 255] do not cast
    # cleanly to uint8, which would produce speckled pixels. Round instead of
    # truncating.
    pixels = np.rint(np.clip(batch.numpy(), 0.0, 1.0) * 255).astype(np.uint8)

    # squeeze() drops the singleton channel axis (assumed (H, W, 1) per sample),
    # so PIL infers 8-bit grayscale ("L") from the 2-D uint8 array without the
    # deprecated `mode` argument.
    return [Image.fromarray(img.squeeze()) for img in pixels]
# Web UI: a single digit input and a one-row gallery of generated samples.
demo = gr.Interface(
    fn=generate,
    # Bound the input to the digits 0-9 promised by the label, so the UI
    # rejects out-of-range values before they ever reach the model.
    inputs=gr.Number(
        label="Digit 0-9",
        precision=0,
        value=4,
        minimum=0,
        maximum=9,
    ),
    outputs=gr.Gallery(label="Five samples", columns=5, rows=1),
    title="Hand-written Digit Generator (cGAN · 20 epochs)",
    description="Pick a digit and get five MNIST-style images.",
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script; importing this
    # module just builds `demo` without launching it.
    demo.launch()