|
import os |
|
import sys |
|
import io |
|
import tarfile |
|
import torch |
|
import webdataset as wds |
|
import numpy as np |
|
|
|
from tqdm import tqdm |
|
sys.path.append(os.path.join(os.path.dirname(__file__), '..')) |
|
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline |
|
|
|
import datetime |
|
from datasets import load_dataset |
|
from torch.utils.data import DataLoader |
|
import diffusers |
|
import fire |
|
|
|
def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
    """Generate SDXL-turbo activations for LAION-COCO captions and write them
    out as WebDataset tar shards — one shard per hooked UNet block, plus one
    shard with the generated images.

    Args:
        save_path: Root output directory; a timestamped subdirectory is
            created inside it so repeated runs never collide.
        start_at: Skip batches with index < start_at (the stream is still
            consumed up to that point, only generation is skipped).
        finish_at: Stop before processing batch index finish_at.
        dataset_batch_size: Number of captions generated per batch.
    """
    # UNet attention blocks whose input/output activations are cached.
    blocks_to_save = [
        'unet.down_blocks.2.attentions.1',
        'unet.mid_block.attentions.0',
        'unet.up_blocks.0.attentions.0',
        'unet.up_blocks.0.attentions.1',
    ]

    dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
    pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
    pipe.to('cuda')
    pipe.set_progress_bar_config(disable=True)
    dataloader = DataLoader(dataset, batch_size=dataset_batch_size)

    # Timestamped run directory keeps outputs of separate runs apart.
    ct = datetime.datetime.now()
    save_path = os.path.join(save_path, str(ct))
    os.makedirs(save_path, exist_ok=True)

    # One tar writer per hooked block, plus one for the generated images.
    writers = {
        block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
    }
    writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})

    def to_kwargs(kwargs_to_save):
        """Turn the JSON-serializable kwargs into pipeline kwargs by replacing
        the integer 'seed' with a CPU torch.Generator seeded from it."""
        kwargs = kwargs_to_save.copy()
        seed = kwargs.pop('seed')
        # Fix: previously seeded with the enclosing loop variable
        # `num_document` instead of the extracted `seed` value, making the
        # helper silently depend on loop state.
        kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
        return kwargs

    for num_document, batch in tqdm(enumerate(dataloader)):
        if num_document < start_at:
            continue
        if num_document >= finish_at:
            break

        # Saved verbatim alongside each sample so generations can be
        # reproduced exactly (seed == batch index).
        kwargs_to_save = {
            'prompt': batch['caption'],
            'positions_to_cache': blocks_to_save,
            'save_input': True,
            'save_output': True,
            'num_inference_steps': 1,
            'guidance_scale': 0.0,
            'seed': num_document,
            'output_type': 'pil'
        }
        kwargs = to_kwargs(kwargs_to_save)

        output, cache = pipe.run_with_cache(
            **kwargs
        )

        # For each hooked block, store its output activation and the residual
        # update it applied (output - input).
        for block in cache['input'].keys():
            writers[block].write({
                "__key__": f"sample_{num_document}",
                "output.pth": cache['output'][block],
                "diff.pth": cache['output'][block] - cache['input'][block],
                "gen_args.json": kwargs_to_save
            })

        writers['images'].write({
            "__key__": f"sample_{num_document}",
            "images.npy": np.stack(output.images)
        })

    # Flush and close every shard so the tars are valid.
    for writer in writers.values():
        writer.close()
|
|
|
if __name__ == '__main__':

    # Expose main() as a command-line tool; python-fire maps CLI flags
    # (e.g. --start_at, --finish_at) onto main's parameters.
    fire.Fire(main)