# gist_demo/app.py -- Streamlit demo for GIsT (Generative Isovist Transformers)
# Author: Mikhael Johanes
import streamlit as st
import numpy as np
import random
from gist1.vqvae_gpt import VQVAETransformer
from utils.misc import load_params
from utils.isoutil import plot_isovist_sequence_grid
import torch
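# Run inference on the GPU when one is available, otherwise fall back to the CPU.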
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model_paths = ["./models/vqvaegpt_1.pth",
"./models/vqvaegpt_2.pth",
"./models/vqvaegpt_3.pth"]
cfg_path = "./models/param.json"
cfg = load_params(cfg_path)
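# Streamlit caches the loaded transformer per checkpoint index, so the weights are
# read from disk only once and reused across reruns.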
@st.cache_resource
def get_model(index):
    TransformerPath = model_paths[index]
    transformer = VQVAETransformer(cfg)
    transformer.load_state_dict(torch.load(TransformerPath, map_location=device))
    transformer = transformer.to(device)
    transformer.eval()
    return transformer
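# Each generation step occupies 17 tokens: 1 location token followed by 16 isovist
# codebook indices. split_indices reshapes a flat sequence into (batch, step, 17)
# and separates the location tokens from the isovist tokens.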
def split_indices(indices, loc_len=1, isovist_len=16):
    seg_length = loc_len + isovist_len
    batch_size = indices.shape[0]
    splits = indices.reshape(batch_size, -1, seg_length)  # BS(L+I)
    ilocs, iisovists = torch.split(splits, [loc_len, isovist_len], dim=2)  # BSL, BSI
    return ilocs, iisovists
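# Cached decoding helpers. The leading underscore on _model tells st.cache_data not to
# hash the model argument; the plain-Python index values act as the cache key.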
@st.cache_data
def indices_to_loc(_model, indices):
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.indices_to_loc(indices).detach().cpu().numpy()
@st.cache_data
def indices_to_isovist(_model, indices):
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.z_to_isovist(indices).detach().cpu().numpy()
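# Decode a complete token sequence step by step into 2D locations and isovist
# contours, stacked along the step dimension (see the BSCW comment below).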
def indices_to_loc_isovist(model, indices):
    ilocs, iisovists = split_indices(indices, loc_len=1, isovist_len=16)
    locs = []
    sampled_isovists = []
    for i in range(iisovists.shape[1]):
        # iloc = ilocs[:, i, :]
        # locs.append(model.indices_to_loc(iloc).detach().cpu().numpy()) # S X BL
        # iisovist = iisovists[:, i, :] # BI
        # sampled_isovists.append(model.z_to_isovist(iisovist).detach().cpu().numpy()) # S X BCW
        iloc = ilocs[:, i, :].squeeze().tolist()
        iisovist = iisovists[:, i, :].squeeze().tolist()
        iisovist = tuple(iisovist)
        locs.append(indices_to_loc(model, iloc))
        sampled_isovists.append(indices_to_isovist(model, iisovist))
        # sampled_isovists.append(code_to_isovist(model, iisovist))
    locs = np.stack(locs, axis=1)
    sampled_isovists = np.stack(sampled_isovists, axis=1)  # BSCW
    return locs, sampled_isovists
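# Render the first batch element with plot_isovist_sequence_grid and return it as an
# (H, W, C) image array for st.image.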
def plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim):
    loc = locs[0]
    sampled_isovist = sampled_isovists[0]
    sampled_isovist = np.squeeze(sampled_isovist, axis=1)
    fig = plot_isovist_sequence_grid(loc, sampled_isovist, figsize=(8, 6), center=True, lim=lim, alpha=alpha, calculate_lim=calculate_lim).transpose((1, 2, 0))
    return fig
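# Autoregressive sampling: every step produces 17 new tokens (1 location + 16 isovist).
# With loc_init=True the location/direction token is already part of start_indices,
# so one token fewer is sampled.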
def sample(model, start_indices, top_k=100, seed=0, seq_length=None, zeroing=False, lim=1.5, alpha=0.02, loc_init=False, calculate_lim=False):
    start_indices = start_indices.long().to(device)
    steps = seq_length * (1 + 16)  # loc dim + latent
    if loc_init:
        steps -= 1
    sample_indices = model.sample_memorized(start_indices, steps=steps, top_k=top_k, seed=seed, zeroing=zeroing)
    locs, sampled_isovists = indices_to_loc_isovist(model, sample_indices)
    im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
    return im, sample_indices
def plot_indices(model, indices, lim=1.5, alpha=0.02, calculate_lim=False):
    locs, sampled_isovists = indices_to_loc_isovist(model, indices)
    im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
    return im
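# Inline CSS so the 3x3 direction-button grid keeps three equal-width columns
# (presumably to keep the layout usable on narrow/mobile screens).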
st.write('''<style>
    [data-testid="column"] {
        width: calc(33.3333% - 1rem) !important;
        flex: 1 1 calc(33.3333% - 1rem) !important;
        min-width: calc(33% - 1rem) !important;
    }
</style>''', unsafe_allow_html=True)
st.subheader("GIsT: Generative Isovist Transformers")
st.text("Mikhael Johanes, Jeffrey Huang | EPFL Media and Design Lab")
st.write("[[paper](https://papers.cumincad.org/data/works/att/ecaade2023_392.pdf)]")
st.text("Pres [init] to initiate or start over")
options =["Base model", "Palladio", "Mies"]
if 'model' not in st.session_state:
st.session_state.model = None
if st.session_state.model is not None:
index = options.index(st.session_state.model)
else:
index = 0
option = st.selectbox("Select model",(options), index=index)
st.session_state.model = option
if 'tokens' not in st.session_state:
    st.session_state.tokens = None
if 'image' not in st.session_state:
    st.session_state.image = np.ones((600, 800, 3), dtype=np.uint8) * 240
if 'seed' not in st.session_state:
    st.session_state.seed = random.randint(0, 10000000)
index = options.index(st.session_state.model)
transformer = get_model(index)
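# Special tokens appended after the VQ-VAE codebook: 1024 appears to be the start/init
# token (see the init handler below), and 1025-1032 encode the eight movement directions.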
e = 1025
ne = 1026
n = 1027
nw = 1028
w = 1029
sw = 1030
s = 1031
se = 1032
alpha = 0.015
lim = 2.0
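# UI: an [init] button above the canvas and a 3x3 grid of direction buttons with
# 'undo' at its centre.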
init = st.button('init')
cont = st.container()
rows = []
for i in range(3):
    rows.append(st.columns(3, gap='small'))
upleft = rows[0][0].button('$\\nwarrow$', use_container_width=True)
up = rows[0][1].button('$\\uparrow$', use_container_width=True)
upright = rows[0][2].button('$\\nearrow$', use_container_width=True)
left = rows[1][0].button('$\\leftarrow$', use_container_width=True)
undo = rows[1][1].button('undo', use_container_width=True)
right = rows[1][2].button('$\\rightarrow$', use_container_width=True)
downleft = rows[2][0].button('$\\swarrow$', use_container_width=True)
down = rows[2][1].button('$\\downarrow$', use_container_width=True)
downright = rows[2][2].button('$\\searrow$', use_container_width=True)
# st.text("use desktop mode for best experiece in mobile device")
seed = st.number_input('seed', 0, 10000000, st.session_state.seed,1)
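# gen_next appends the chosen direction token and samples the next 16 isovist tokens;
# undo_gen drops the last 17 tokens (one full step) and re-renders the sequence.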
def gen_next(sample_indices, direction):
    # seed = st.session_state.seed
    sample_indices = torch.concat([sample_indices, torch.tensor([[direction]]).to(device)], dim=1)
    im, sample_indices = sample(transformer, sample_indices, top_k=50, seq_length=1, seed=seed, lim=lim, alpha=alpha, loc_init=True, calculate_lim=True)
    return im, sample_indices
def undo_gen(sample_indices):
    sample_indices = sample_indices[:, :-17]
    im = plot_indices(transformer, sample_indices, lim=lim, alpha=alpha, calculate_lim=True)
    return im, sample_indices
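# [init] resets the sequence to the single start token (1024) and samples the first isovist.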
if init:
    st.session_state.tokens = torch.ones((1, 1)).long().to(device) * 1024
    tokens = st.session_state.tokens
    # seed = st.session_state.seed
    im, sample_indices = sample(transformer, tokens, top_k=50, seq_length=1, seed=seed, lim=lim, alpha=alpha, loc_init=True)
    st.session_state.image = im
    st.session_state.tokens = sample_indices
    st.session_state.lim = 2.0
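# Direction handlers: each pressed arrow extends the walk by one step in that direction;
# all of them require an initialised token sequence.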
for pressed, direction in [(upleft, nw), (up, n), (upright, ne),
                           (left, w), (right, e),
                           (downleft, sw), (down, s), (downright, se)]:
    if pressed:
        if st.session_state.tokens is not None:
            st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, direction)
        else:
            st.warning('Please init the generation')
if undo:
    if st.session_state.tokens is not None:
        if st.session_state.tokens.shape[1] >= 34:
            st.session_state.image, st.session_state.tokens = undo_gen(st.session_state.tokens)
        else:
            st.warning('no more steps to undo')
    else:
        st.warning('Please init the generation')
cont.image(st.session_state.image)