File size: 1,071 Bytes
42f2be9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import gradio as gr
import tensorflow as tf
import numpy as np
from pathlib import Path
from PIL import Image

# Length of the random noise vector sampled for each generated image.
LATENT_DIM = 100
# Saved generator model, looked up in the same directory as this script.
MODEL_FILE = Path(__file__).with_name("generator_full.keras")

_gen = None  # loaded lazily so the app starts up quickly


def get_generator():
    """Return the generator model, loading it from ``MODEL_FILE`` on first use.

    The loaded model is cached in the module-level ``_gen`` so later calls
    reuse it instead of reading the file again.
    """
    global _gen
    if _gen is not None:
        return _gen
    # compile=False: the model is only called for inference in this file.
    _gen = tf.keras.models.load_model(MODEL_FILE, compile=False)
    return _gen

def generate(digit: int, n_samples: int = 5):
    """Sample images of ``digit`` from the conditional generator.

    Args:
        digit: Class label to condition on; a whole number from 0 to 9.
        n_samples: Number of images to sample. Defaults to 5.

    Returns:
        A list of ``n_samples`` 8-bit greyscale (mode ``"L"``) PIL images.

    Raises:
        ValueError: If ``digit`` is missing, not numeric, or outside 0-9.
    """
    # Validate up front so a cleared or out-of-range input yields a clear
    # message instead of an opaque error from inside the model.
    try:
        label = int(digit)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("Digit must be a whole number between 0 and 9.") from None
    if not 0 <= label <= 9:
        raise ValueError(f"Digit must be between 0 and 9, got {label}.")

    z = tf.random.normal([n_samples, LATENT_DIM])
    lbl = tf.constant([[label]] * n_samples)
    # Rescale to [0, 1]; assumes the generator emits values in [-1, 1]
    # (e.g. a tanh output layer) -- TODO confirm against the trained model.
    imgs = (get_generator()([z, lbl], training=False) + 1) / 2
    # Convert the whole batch once, clipping before the uint8 cast so any
    # stray out-of-range value saturates instead of wrapping around.
    pixels = np.clip(imgs.numpy() * 255, 0, 255).astype("uint8")
    return [Image.fromarray(img.squeeze(), mode="L") for img in pixels]

# UI components: one integer input and a single-row gallery for the output.
_digit_input = gr.Number(label="Digit 0-9", precision=0, value=4)
_samples_gallery = gr.Gallery(label="Five samples", columns=5, rows=1)

# Gradio app wiring the digit input to `generate` and its images to the gallery.
demo = gr.Interface(
    fn=generate,
    inputs=_digit_input,
    outputs=_samples_gallery,
    title="Hand-written Digit Generator (cGAN · 20 epochs)",
    description="Pick a digit and get five MNIST-style images.",
)

# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()