|
import streamlit as st |
|
import numpy as np |
|
import random |
|
|
|
from gist1.vqvae_gpt import VQVAETransformer |
|
from utils.misc import load_params |
|
from utils.isoutil import plot_isovist_sequence_grid |
|
|
|
|
|
import torch |
|
|
|
|
|
# Run inference on the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
# One checkpoint per entry of the "Select model" options list; get_model(i)
# loads model_paths[i], where i is the position of the chosen option.
model_paths = [
    "./models/vqvaegpt_1.pth",
    "./models/vqvaegpt_2.pth",
    "./models/vqvaegpt_3.pth",
]

# Hyper-parameters used to construct every VQVAETransformer (see get_model).
cfg_path = "./models/param.json"
cfg = load_params(cfg_path)
|
|
|
|
|
|
|
@st.cache_resource
def get_model(index):
    """Build the VQVAETransformer stored at ``model_paths[index]``.

    Wrapped in ``st.cache_resource``, so each ``index`` is loaded once and
    the same model object is returned on later calls.

    Args:
        index: position in ``model_paths`` (matches the selector option).

    Returns:
        The model with weights loaded, moved to ``device`` and set to eval mode.
    """
    checkpoint_path = model_paths[index]
    model = VQVAETransformer(cfg)
    # NOTE(review): torch.load unpickles the file; only point it at trusted
    # checkpoints.
    weights = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(weights)
    model = model.to(device)
    model.eval()  # inference only: disable dropout etc.
    return model
|
|
|
|
|
def split_indices(indices, loc_len=1, isovist_len=16):
    """Split a flat token sequence into location and isovist parts per step.

    The tokens after the batch axis are read as consecutive segments of
    ``loc_len`` location tokens followed by ``isovist_len`` isovist tokens.

    Args:
        indices: tensor whose first axis is the batch; the remaining elements
            must be a whole number of ``loc_len + isovist_len`` segments
            (otherwise ``reshape`` raises).
        loc_len: location tokens at the start of each segment.
        isovist_len: isovist tokens following them.

    Returns:
        ``(ilocs, iisovists)`` with shapes ``(batch, n_segments, loc_len)``
        and ``(batch, n_segments, isovist_len)``.
    """
    n_batch = indices.shape[0]
    segment_len = loc_len + isovist_len
    segments = indices.reshape(n_batch, -1, segment_len)
    return tuple(segments.split([loc_len, isovist_len], dim=-1))
|
|
|
@st.cache_data
def indices_to_loc(_model, indices, model_key=None):
    """Decode location token id(s) with ``_model`` into a NumPy array.

    Memoised with ``st.cache_data``. Streamlit does not hash parameters whose
    name starts with an underscore, so ``_model`` is NOT part of the cache
    key. ``model_key`` is what keeps results from different models apart.

    Args:
        _model: object exposing ``indices_to_loc(tensor)``; not hashed.
        indices: token id(s), an int or a sequence of ints.
        model_key: hashable identifier of ``_model`` that is included in the
            cache key. The default ``None`` keeps the old model-agnostic key.

    Returns:
        The output of ``_model.indices_to_loc`` for the ids, given as a
        ``(1, -1)`` long tensor on ``device``, detached and returned as a
        CPU NumPy array.
    """
    # model_key is intentionally unused in the body: it only feeds the cache key.
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.indices_to_loc(indices).detach().cpu().numpy()


@st.cache_data
def indices_to_isovist(_model, indices, model_key=None):
    """Decode isovist token ids with ``_model.z_to_isovist`` into a NumPy array.

    Cached like ``indices_to_loc``: ``_model`` is not hashed, so pass a
    per-model ``model_key`` to keep decodes from different models apart.

    Args:
        _model: object exposing ``z_to_isovist(tensor)``; not hashed.
        indices: sequence of token ids (a tuple when called from
            ``indices_to_loc_isovist``).
        model_key: hashable identifier of ``_model`` that is included in the
            cache key. The default ``None`` keeps the old model-agnostic key.

    Returns:
        The output of ``_model.z_to_isovist`` for a ``(1, -1)`` long tensor,
        detached and returned as a CPU NumPy array.
    """
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.z_to_isovist(indices).detach().cpu().numpy()


def indices_to_loc_isovist(model, indices):
    """Decode a generated token sequence into per-step locations and isovists.

    Args:
        model: decoder model, normally obtained from ``get_model``.
        indices: token tensor laid out as repeated segments of 1 location
            token plus 16 isovist tokens (see ``split_indices``). Assumes a
            batch size of 1: the per-step ``squeeze()`` drops the batch axis.

    Returns:
        ``(locs, sampled_isovists)``: the per-step decoder outputs stacked
        along axis 1.
    """
    ilocs, iisovists = split_indices(indices, loc_len=1, isovist_len=16)

    # The cached decoders never hash the model itself, so key them on the
    # model object's identity. That way switching checkpoints in the UI
    # cannot return decodes produced by a different model. Assumes models
    # are long-lived (get_model is an st.cache_resource), so id() is stable.
    model_key = id(model)

    locs = []
    sampled_isovists = []
    for i in range(iisovists.shape[1]):
        iloc = ilocs[:, i, :].squeeze().tolist()
        # Tuple so the argument is an immutable, hashable cache key.
        iisovist = tuple(iisovists[:, i, :].squeeze().tolist())
        locs.append(indices_to_loc(model, iloc, model_key=model_key))
        sampled_isovists.append(
            indices_to_isovist(model, iisovist, model_key=model_key)
        )

    locs = np.stack(locs, axis=1)
    sampled_isovists = np.stack(sampled_isovists, axis=1)
    return locs, sampled_isovists
|
|
|
def plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim):
    """Render the first sequence of a decoded batch as an image array.

    Args:
        locs: stacked location decodes; only ``locs[0]`` is drawn.
        sampled_isovists: stacked isovist decodes; ``[0]`` is taken and its
            axis 1 squeezed away (that axis must have size 1).
        lim, alpha, calculate_lim: forwarded to ``plot_isovist_sequence_grid``.

    Returns:
        The grid image with its first axis moved last via
        ``transpose((1, 2, 0))`` (presumably CHW -> HWC for ``st.image``).
    """
    first_isovists = np.squeeze(sampled_isovists[0], axis=1)
    grid = plot_isovist_sequence_grid(
        locs[0],
        first_isovists,
        figsize=(8, 6),
        center=True,
        lim=lim,
        alpha=alpha,
        calculate_lim=calculate_lim,
    )
    return grid.transpose((1, 2, 0))
|
|
|
def sample(model, start_indices, top_k=100, seed=0, seq_length=None,
           zeroing=False, lim=1.5, alpha=0.02, loc_init=False,
           calculate_lim=False):
    """Extend ``start_indices`` with ``model.sample_memorized`` and render it.

    Each step is 17 tokens: 1 location token followed by 16 isovist tokens.
    With ``loc_init=True`` one fewer token is sampled. The last token of
    ``start_indices`` then serves as the location token of the first new
    step; ``gen_next`` supplies it by appending a direction token.

    Args:
        model: model exposing ``sample_memorized``.
        start_indices: prompt token tensor; cast to long and moved to ``device``.
        top_k, seed, zeroing: forwarded to ``model.sample_memorized``.
        seq_length: number of 17-token steps to generate. Required.
        lim, alpha, calculate_lim: forwarded to ``plot_isovist``.
        loc_init: see above.

    Returns:
        ``(im, sample_indices)``: the rendered image and the full token tensor
        returned by the model.

    Raises:
        TypeError: if ``seq_length`` is not given. Previously this surfaced as
            an opaque ``None * int`` error after the tensor had already been
            moved to the device.
    """
    if seq_length is None:
        raise TypeError("sample() requires seq_length (number of steps to generate)")

    start_indices = start_indices.long().to(device)
    steps = seq_length * (1 + 16)
    if loc_init:
        steps -= 1  # the location token of the first step is already present

    sample_indices = model.sample_memorized(
        start_indices, steps=steps, top_k=top_k, seed=seed, zeroing=zeroing
    )
    locs, sampled_isovists = indices_to_loc_isovist(model, sample_indices)
    im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
    return im, sample_indices
|
|
|
|
|
def plot_indices(model, indices, lim=1.5, alpha=0.02, calculate_lim=False):
    """Decode an existing token sequence and render it, without sampling.

    Args:
        model: decoder model passed to ``indices_to_loc_isovist``.
        indices: token tensor of 17-token steps.
        lim, alpha, calculate_lim: forwarded to ``plot_isovist``.

    Returns:
        The image array produced by ``plot_isovist``.
    """
    decoded_locs, decoded_isovists = indices_to_loc_isovist(model, indices)
    return plot_isovist(decoded_locs, decoded_isovists, lim, alpha, calculate_lim)
|
|
|
# Force every Streamlit column to roughly one third of the row width,
# presumably so the 3x3 direction pad below stays three columns wide even
# on narrow viewports.
st.write('''<style>

[data-testid="column"] {
    width: calc(33.3333% - 1rem) !important;
    flex: 1 1 calc(33.3333% - 1rem) !important;
    min-width: calc(33% - 1rem) !important;
}
</style>''', unsafe_allow_html=True)

# Page title, credits and paper link.
st.subheader("GIsT: Generative Isovist Transformers")
st.text("Mikhael Johanes, Jeffrey Huang | EPFL Media and Design Lab")
st.write("[[paper](https://papers.cumincad.org/data/works/att/ecaade2023_392.pdf)]")
# Fixed typo in the user-facing hint ("Pres" -> "Press").
st.text("Press [init] to initiate or start over")
|
# Display names for the checkpoints; the position of the chosen name is used
# as the index into model_paths when the model is loaded.
options = ["Base model", "Palladio", "Mies"]

if 'model' not in st.session_state:
    st.session_state.model = None

# Pre-select the model chosen on the previous run (first option if none yet).
previous_choice = st.session_state.model
index = 0 if previous_choice is None else options.index(previous_choice)

option = st.selectbox("Select model", options, index=index)
st.session_state.model = option
|
|
|
|
|
def _init_session_defaults():
    """Populate st.session_state with first-run values; keep existing ones."""
    if 'tokens' not in st.session_state:
        # No sequence yet; the [init] button creates one.
        st.session_state.tokens = None
    if 'image' not in st.session_state:
        # Light-grey 600x800 RGB placeholder shown before anything is generated.
        st.session_state.image = np.full((600, 800, 3), 240, dtype=np.uint8)
    if 'seed' not in st.session_state:
        # Random default value for the seed widget.
        st.session_state.seed = random.randint(0, 10000000)


_init_session_defaults()
|
|
|
|
|
|
|
# Load the checkpoint matching the selected model name; get_model is cached
# with st.cache_resource, so this only hits disk on the first request.
index = options.index(st.session_state["model"])
transformer = get_model(index)
|
|
|
|
|
# Token ids for the eight pad directions, counter-clockwise from east
# (1025 .. 1032). gen_next appends one of these as the location token of the
# next step.
e, ne, n, nw, w, sw, s, se = range(1025, 1033)

# Rendering parameters passed to every plotting call below.
alpha = 0.015  # presumably the isovist fill transparency
lim = 2.0  # presumably the plot extent
|
|
|
# Restart the generation from scratch.
init = st.button('init')

# Placeholder created here so the image renders above the controls, even
# though it is only filled at the very end of the script.
cont = st.container()

# Three rows of three columns that host the direction pad.
rows = [st.columns(3, gap='small') for _ in range(3)]


def _pad_button(row, col, label):
    """Create a full-width button in cell (row, col) of the direction pad."""
    return rows[row][col].button(label, use_container_width=True)


upleft = _pad_button(0, 0, '$\\nwarrow$')
up = _pad_button(0, 1, '$\\uparrow$')
upright = _pad_button(0, 2, '$\\nearrow$')
left = _pad_button(1, 0, '$\\leftarrow$')
undo = _pad_button(1, 1, 'undo')
right = _pad_button(1, 2, '$\\rightarrow$')
downleft = _pad_button(2, 0, '$\\swarrow$')
down = _pad_button(2, 1, '$\\downarrow$')
downright = _pad_button(2, 2, '$\\searrow$')

# Sampling seed; defaults to the per-session random value.
seed = st.number_input(
    'seed',
    min_value=0,
    max_value=10000000,
    value=st.session_state.seed,
    step=1,
)
|
|
|
|
|
def gen_next(sample_indices, dir):
    """Append direction token ``dir`` and sample the following 17-token step.

    Args:
        sample_indices: current token tensor of shape ``(1, L)``.
        dir: direction token id (one of e, ne, n, nw, w, sw, s, se).

    Returns:
        ``(im, sample_indices)`` from ``sample``: the new render and the
        extended token tensor.
    """
    direction_token = torch.tensor([[dir]]).to(device)
    extended = torch.concat([sample_indices, direction_token], dim=1)
    # loc_init=True: the appended direction token is the new step's location
    # token, so sample() generates only the remaining tokens of that step.
    return sample(
        transformer,
        extended,
        top_k=50,
        seq_length=1,
        seed=seed,
        lim=lim,
        alpha=alpha,
        loc_init=True,
        calculate_lim=True,
    )
|
|
|
def undo_gen(sample_indices):
    """Drop the most recent step and re-render the remaining sequence.

    Args:
        sample_indices: token tensor of shape ``(1, L)``.

    Returns:
        ``(im, sample_indices)``: the new render and the trimmed token tensor.
    """
    step_len = 1 + 16  # one location token + 16 isovist tokens per step
    trimmed = sample_indices[:, :-step_len]
    return plot_indices(transformer, trimmed, lim=lim, alpha=alpha, calculate_lim=True), trimmed
|
|
|
if init:
    # Fresh sequence containing only the start token 1024, on `device`.
    tokens = torch.full((1, 1), 1024, dtype=torch.long, device=device)
    st.session_state.tokens = tokens

    # Unlike the direction steps, calculate_lim stays at its default here.
    st.session_state.image, st.session_state.tokens = sample(
        transformer, tokens, top_k=50, seq_length=1, seed=seed,
        lim=lim, alpha=alpha, loc_init=True,
    )
    # NOTE(review): st.session_state.lim is not read anywhere in this file.
    st.session_state.lim = 2.0
|
|
|
# Pair each direction-pad button with its move token. They are checked in the
# same order as before: upleft, up, upright, left, right, downleft, down,
# downright.
direction_buttons = (
    (upleft, nw),
    (up, n),
    (upright, ne),
    (left, w),
    (right, e),
    (downleft, sw),
    (down, s),
    (downright, se),
)

for pressed, move in direction_buttons:
    if not pressed:
        continue
    if st.session_state.tokens is None:
        st.warning('Please init the generation')
    else:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, move)
|
|
|
|
|
if undo:
    if st.session_state.tokens is None:
        st.warning('Please init the generation')
    elif st.session_state.tokens.shape[1] < 34:
        # Fewer than two 17-token steps: only the initial step is left.
        st.warning('no more step to undo')
    else:
        st.session_state.image, st.session_state.tokens = undo_gen(st.session_state.tokens)

# Fill the placeholder above the controls with the latest render.
cont.image(st.session_state.image)
|
|