File size: 2,569 Bytes
b0b9200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the VAE model
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for fixed-size grids.

    The encoder applies two stride-2 convolutions and one stride-1
    convolution. It flattens the resulting 128-channel feature map and
    projects it to the mean and log-variance of a ``latent_dim``-dimensional
    Gaussian. The decoder mirrors that path and ends with a softmax over the
    channel axis, so every cell of the output holds a probability
    distribution over ``input_channels`` classes (presumably one-hot tile
    types — confirm against the training data).

    Args:
        input_channels: Number of channels (classes) in the input grid. The
            decoder reconstructs the same number of channels.
        latent_dim: Size of the latent vector.
        grid_size: ``(height, width)`` of the grids the model is built for.
            The default ``(15, 40)`` reproduces the original hard-coded
            layout: a 128x4x10 bottleneck, i.e. 5120 flattened features.
            With the default, parameter names and shapes are unchanged and
            existing checkpoints still load.
    """

    def __init__(self, input_channels=3, latent_dim=16, grid_size=(15, 40)):
        super().__init__()
        self.latent_dim = latent_dim
        height, width = grid_size
        # A k=3, s=2, p=1 convolution maps a side of length n to ceil(n / 2).
        mid_h, mid_w = (height + 1) // 2, (width + 1) // 2
        low_h, low_w = (mid_h + 1) // 2, (mid_w + 1) // 2
        # Shape of one sample at the bottleneck (enc_conv3 keeps the size).
        self._bottleneck_shape = (128, low_h, low_w)
        flat_dim = 128 * low_h * low_w

        self.enc_conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, flat_dim)
        self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
        # A k=3, s=2, p=1 transposed conv maps n to 2n - 1 + output_padding.
        # Pick the padding (always 0 or 1) that restores each encoder side.
        self.dec_conv2 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1,
            output_padding=(mid_h - 2 * low_h + 1, mid_w - 2 * low_w + 1),
        )
        # Emit one output channel per input channel so the reconstruction
        # matches the input layout.
        self.dec_conv3 = nn.ConvTranspose2d(
            32, input_channels, kernel_size=3, stride=2, padding=1,
            output_padding=(height - 2 * mid_h + 1, width - 2 * mid_w + 1),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``.

        ``x`` is expected to be ``(batch, input_channels, height, width)``.
        Inputs whose bottleneck size differs from that of ``grid_size`` fail
        at ``fc_mu``.
        """
        x = F.relu(self.enc_conv1(x))
        x = F.relu(self.enc_conv2(x))
        x = F.relu(self.enc_conv3(x))
        x = x.view(x.size(0), -1)
        return self.fc_mu(x), self.fc_logvar(x)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and return its decoded grid.

        NOTE(review): only the reconstruction is returned. ``mu`` and
        ``logvar`` are discarded, so a training loop that needs the KL term
        must call ``encode`` itself.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def decode(self, z):
        """Map latent codes of shape ``(batch, latent_dim)`` to grids.

        Returns a tensor of shape ``(batch, input_channels, height, width)``
        holding a softmax over the channel axis for every cell.
        """
        x = F.relu(self.fc_decode(z))
        x = x.view(x.size(0), *self._bottleneck_shape)
        x = F.relu(self.dec_conv1(x))
        x = F.relu(self.dec_conv2(x))
        x = self.dec_conv3(x)
        return F.softmax(x, dim=1)

# Load model
model = ConvVAE()
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
model.eval()

def generate_map(seed: Optional[int] = None):
    """Sample a level from the VAE prior and render it as rows of digits.

    Args:
        seed: Optional RNG seed. Any value other than ``None`` (including 0)
            makes the result reproducible. Floats such as those produced by
            ``gr.Number`` are truncated with ``int()``, matching what
            ``torch.manual_seed`` does.

    Returns:
        A list of strings, one per grid row. Five all-zero rows are
        prepended above the decoded grid (15x40 with the default
        ``ConvVAE``), giving 20 rows. Each character is the argmax class
        index of its cell.
    """
    # A private generator keeps seeding local to this call instead of
    # resetting torch's process-wide RNG. Without a seed, fall back to the
    # global RNG (generator=None), as before.
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(int(seed))
    z = torch.randn(1, model.latent_dim, generator=generator)
    with torch.no_grad():
        probs = model.decode(z)  # (1, C, 15, 40) for the default ConvVAE
    grid = probs.squeeze(0).argmax(dim=0).cpu().numpy()
    # Prepend (not append) 5 rows of zeros above the decoded grid.
    padding = np.zeros((5, grid.shape[1]), dtype=grid.dtype)
    padded_grid = np.vstack([padding, grid])
    return ["".join(map(str, row)) for row in padded_grid]

# Expose generate_map as a one-input web UI and start serving it.
demo = gr.Interface(
    fn=generate_map,
    inputs=gr.Number(label="Seed"),
    outputs=gr.JSON(label="Generated Map Grid"),
    title="VAE Level Generator",
    description="Returns a 20x40 grid as a list of strings where 0=air, 1=ground, 2=lava",
)
demo.launch()