import random

import numpy as np
import streamlit as st
import torch

from gist1.vqvae_gpt import VQVAETransformer
from utils.misc import load_params
from utils.isoutil import plot_isovist_sequence_grid

# Select GPU when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model_paths = ["./models/vqvaegpt_1.pth", "./models/vqvaegpt_2.pth", "./models/vqvaegpt_3.pth"]
cfg_path = "./models/param.json"
cfg = load_params(cfg_path)


@st.cache_resource
def get_model(index):
    """Load and cache the VQ-VAE transformer checkpoint for the selected model."""
    TransformerPath = model_paths[index]
    transformer = VQVAETransformer(cfg)
    transformer.load_state_dict(torch.load(TransformerPath, map_location=device))
    transformer = transformer.to(device)
    transformer.eval()
    return transformer


def split_indices(indices, loc_len=1, isovist_len=16):
    """Split a flat token sequence into per-step location and isovist indices."""
    seg_length = loc_len + isovist_len
    batch_size = indices.shape[0]
    splits = indices.reshape(batch_size, -1, seg_length)  # B x S x (L + I)
    ilocs, iisovists = torch.split(splits, [loc_len, isovist_len], dim=2)  # B x S x L, B x S x I
    return ilocs, iisovists


@st.cache_data
def indices_to_loc(_model, indices):
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.indices_to_loc(indices).detach().cpu().numpy()


@st.cache_data
def indices_to_isovist(_model, indices):
    indices = torch.tensor(indices).long().view(1, -1).to(device)
    return _model.z_to_isovist(indices).detach().cpu().numpy()


def indices_to_loc_isovist(model, indices):
    """Decode a token sequence into location coordinates and isovist shapes."""
    ilocs, iisovists = split_indices(indices, loc_len=1, isovist_len=16)
    locs = []
    sampled_isovists = []
    for i in range(iisovists.shape[1]):
        iloc = ilocs[:, i, :].squeeze().tolist()
        iisovist = tuple(iisovists[:, i, :].squeeze().tolist())  # hashable for st.cache_data
        locs.append(indices_to_loc(model, iloc))
        sampled_isovists.append(indices_to_isovist(model, iisovist))
    locs = np.stack(locs, axis=1)
    sampled_isovists = np.stack(sampled_isovists, axis=1)  # B x S x C x W
    return locs, sampled_isovists


def plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim):
    """Render the decoded isovist sequence of the first batch item as an image array."""
    loc = locs[0]
    sampled_isovist = np.squeeze(sampled_isovists[0], axis=1)
    fig = plot_isovist_sequence_grid(loc, sampled_isovist, figsize=(8, 6), center=True, lim=lim,
                                     alpha=alpha, calculate_lim=calculate_lim).transpose((1, 2, 0))
    return fig


def sample(model, start_indices, top_k=100, seed=0, seq_length=None, zeroing=False, lim=1.5,
           alpha=0.02, loc_init=False, calculate_lim=False):
    """Autoregressively sample new tokens and return the rendered image and the token sequence."""
    start_indices = start_indices.long().to(device)
    steps = seq_length * (1 + 16)  # one location token + 16 latent tokens per step
    if loc_init:
        steps -= 1  # the location token is already provided in start_indices
    sample_indices = model.sample_memorized(start_indices, steps=steps, top_k=top_k, seed=seed, zeroing=zeroing)
    locs, sampled_isovists = indices_to_loc_isovist(model, sample_indices)
    im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
    return im, sample_indices


def plot_indices(model, indices, lim=1.5, alpha=0.02, calculate_lim=False):
    """Render an existing token sequence without sampling any new tokens."""
    locs, sampled_isovists = indices_to_loc_isovist(model, indices)
    im = plot_isovist(locs, sampled_isovists, lim, alpha, calculate_lim)
    return im


st.write('''''', unsafe_allow_html=True)
st.subheader("GIsT: Generative Isovist Transformers")
st.text("Mikhael Johanes, Jeffrey Huang | EPFL Media and Design Lab")
st.write("[[paper](https://papers.cumincad.org/data/works/att/ecaade2023_392.pdf)]")
st.text("Press [init] to initiate or start over")

options = ["Base model", "Palladio", "Mies"]

if 'model' not in st.session_state:
    st.session_state.model = None

if st.session_state.model is not None:
    index = options.index(st.session_state.model)
else:
    index = 0

option = st.selectbox("Select model", options, index=index)
st.session_state.model = option

if 'tokens' not in st.session_state:
    st.session_state.tokens = None
if 'image' not in st.session_state:
    st.session_state.image = np.ones((600, 800, 3), dtype=np.uint8) * 240
if 'seed' not in st.session_state:
    st.session_state.seed = random.randint(0, 10000000)

index = options.index(st.session_state.model)
transformer = get_model(index)

# Direction tokens appended after the codebook indices.
e = 1025
ne = 1026
n = 1027
nw = 1028
w = 1029
sw = 1030
s = 1031
se = 1032

alpha = 0.015
lim = 2.0

init = st.button('init')
cont = st.container()

# 3 x 3 grid of direction buttons with 'undo' in the middle.
rows = []
for i in range(3):
    rows.append(st.columns(3, gap='small'))

upleft = rows[0][0].button('$\\nwarrow$', use_container_width=True)
up = rows[0][1].button('$\\uparrow$', use_container_width=True)
upright = rows[0][2].button('$\\nearrow$', use_container_width=True)
left = rows[1][0].button('$\\leftarrow$', use_container_width=True)
undo = rows[1][1].button('undo', use_container_width=True)
right = rows[1][2].button('$\\rightarrow$', use_container_width=True)
downleft = rows[2][0].button('$\\swarrow$', use_container_width=True)
down = rows[2][1].button('$\\downarrow$', use_container_width=True)
downright = rows[2][2].button('$\\searrow$', use_container_width=True)

# st.text("use desktop mode for the best experience on a mobile device")
seed = st.number_input('seed', 0, 10000000, st.session_state.seed, 1)


def gen_next(sample_indices, dir):
    """Append a direction token and sample the next location/isovist segment."""
    sample_indices = torch.concat([sample_indices, torch.tensor([[dir]]).to(device)], dim=1)
    im, sample_indices = sample(transformer, sample_indices, top_k=50, seq_length=1, seed=seed,
                                lim=lim, alpha=alpha, loc_init=True, calculate_lim=True)
    return im, sample_indices


def undo_gen(sample_indices):
    """Drop the last segment (1 location token + 16 isovist tokens) and re-render."""
    sample_indices = sample_indices[:, :-17]
    im = plot_indices(transformer, sample_indices, lim=lim, alpha=alpha, calculate_lim=True)
    return im, sample_indices


if init:
    # Start a new sequence from the initial location token (1024).
    st.session_state.tokens = torch.ones((1, 1)).long().to(device) * 1024
    tokens = st.session_state.tokens
    im, sample_indices = sample(transformer, tokens, top_k=50, seq_length=1, seed=seed, lim=lim,
                                alpha=alpha, loc_init=True)
    st.session_state.image = im
    st.session_state.tokens = sample_indices
    st.session_state.lim = 2.0

if upleft:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, nw)
    else:
        st.warning('Please init the generation')

if up:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, n)
    else:
        st.warning('Please init the generation')

if upright:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, ne)
    else:
        st.warning('Please init the generation')

if left:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, w)
    else:
        st.warning('Please init the generation')

if right:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, e)
    else:
        st.warning('Please init the generation')

if downleft:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, sw)
    else:
        st.warning('Please init the generation')

if down:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, s)
    else:
        st.warning('Please init the generation')

if downright:
    if st.session_state.tokens is not None:
        st.session_state.image, st.session_state.tokens = gen_next(st.session_state.tokens, se)
    else:
        st.warning('Please init the generation')

if undo:
    if st.session_state.tokens is not None:
        if st.session_state.tokens.shape[1] >= 34:  # keep at least one generated segment
            st.session_state.image, st.session_state.tokens = undo_gen(st.session_state.tokens)
        else:
            st.warning('no more steps to undo')
    else:
        st.warning('Please init the generation')

cont.image(st.session_state.image)