diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..23f2e229ff2fca94b6d015ba18e41ecc1cdfa9af
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,155 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..0704d550acddc5efeb7a51d156f413a1f791b857
--- /dev/null
+++ b/app.py
@@ -0,0 +1,205 @@
+# Not ready to use yet
+import argparse
+import numpy as np
+import gradio as gr
+from omegaconf import OmegaConf
+import torch
+from PIL import Image
+import PIL
+from pipelines import TwoStagePipeline
+from huggingface_hub import hf_hub_download
+import os
+import rembg
+from typing import Any
+import json
+
+from model import CRM
+from inference import generate3d
+
+pipeline = None
+rembg_session = rembg.new_session()
+
+
+def check_input_image(input_image):
+    if input_image is None:
+        raise gr.Error("No image uploaded!")
+
+
+def remove_background(
+    image: PIL.Image.Image,
+    rembg_session: Any = None,
+    force: bool = False,
+    **rembg_kwargs,
+) -> PIL.Image.Image:
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        # the image already carries a meaningful alpha channel, so skip background removal and use it as the mask
+        print("alpha channel not empty; skipping background removal and using the alpha channel as the mask")
+        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
+        image = Image.alpha_composite(background, image)
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
+
+def do_resize_content(original_image: Image.Image, scale_rate):
+    # resize the image content by scale_rate while keeping the original canvas size
+    if scale_rate != 1:
+        # Calculate the new size after rescaling
+        new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
+        # Resize the image while maintaining the aspect ratio
+        resized_image = original_image.resize(new_size)
+        # Create a new image with the original size and black background
+        padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
+        paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
+        padded_image.paste(resized_image, paste_position)
+        return padded_image
+    else:
+        return original_image
+
+def add_background(image, bg_color=(255, 255, 255)):
+    # given an RGBA image, composite it over the background color using its alpha channel as the mask
+    background = Image.new("RGBA", image.size, bg_color)
+    return Image.alpha_composite(background, image)
+
+
+def preprocess_image(input_image, do_remove_background, force_remove, foreground_ratio, background_color):
+    """
+    Input image is a PIL image in RGBA mode; returns an RGB image.
+    """
+    if do_remove_background:
+        image = remove_background(input_image, rembg_session, force_remove)
+    else:
+        image = input_image
+    image = do_resize_content(image, foreground_ratio)
+    image = add_background(image, background_color)
+    return image.convert("RGB")
+
+
+def gen_image(input_image, seed, scale, step):
+    global pipeline, model, args
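+    # stage 1 produces the multi-view RGB images and stage 2 the matching CCM (xyz) maps; each set is tiled horizontally and fed to the reconstruction model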
+    pipeline.set_seed(seed)
+    rt_dict = pipeline(input_image, scale=scale, step=step)
+    stage1_images = rt_dict["stage1_images"]
+    stage2_images = rt_dict["stage2_images"]
+    np_imgs = np.concatenate(stage1_images, 1)
+    np_xyzs = np.concatenate(stage2_images, 1)
+
+    glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, args.device)
+    return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path, obj_path
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--stage1_config",
+    type=str,
+    default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
+    help="config for stage1",
+)
+parser.add_argument(
+    "--stage2_config",
+    type=str,
+    default="configs/stage2-v2-snr.yaml",
+    help="config for stage2",
+)
+
+parser.add_argument("--device", type=str, default="cuda")
+args = parser.parse_args()
+
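+# download the CRM reconstruction checkpoint from the Hugging Face Hub and load it (strict=False tolerates key mismatches)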
+crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
+specs = json.load(open("configs/specs_objaverse_total.json"))
+model = CRM(specs).to(args.device)
+model.load_state_dict(torch.load(crm_path, map_location=args.device), strict=False)
+
+stage1_config = OmegaConf.load(args.stage1_config).config
+stage2_config = OmegaConf.load(args.stage2_config).config
+stage2_sampler_config = stage2_config.sampler
+stage1_sampler_config = stage1_config.sampler
+
+stage1_model_config = stage1_config.models
+stage2_model_config = stage2_config.models
+
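+# the two diffusion checkpoints: pixel-diffusion drives stage 1 (RGB multi-views), ccm-diffusion drives stage 2 (canonical coordinate maps)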
+xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
+pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
+stage1_model_config.resume = pixel_path
+stage2_model_config.resume = xyz_path
+
+pipeline = TwoStagePipeline(
+    stage1_model_config,
+    stage2_model_config,
+    stage1_sampler_config,
+    stage2_sampler_config,
+    device=args.device,
+    dtype=torch.float16
+)
+
+with gr.Blocks() as demo:
+    gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                image_input = gr.Image(
+                    label="Image input",
+                    image_mode="RGBA",
+                    sources="upload",
+                    type="pil",
+                )
+                processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        do_remove_background = gr.Checkbox(label="Remove Background", value=True)
+                        force_remove = gr.Checkbox(label="Force Remove", value=False)
+                    background_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
+                    foreground_ratio = gr.Slider(
+                        label="Foreground Ratio",
+                        minimum=0.5,
+                        maximum=1.0,
+                        value=1.0,
+                        step=0.05,
+                    )
+
+                with gr.Column():
+                    seed = gr.Number(value=1234, label="seed", precision=0)
+                    guidance_scale = gr.Number(value=5.5, minimum=0, maximum=20, label="guidance_scale")
+                    step = gr.Number(value=50, minimum=1, maximum=100, label="sample steps", precision=0)
+            text_button = gr.Button("Generate Images")
+        with gr.Column():
+            image_output = gr.Image(interactive=False, label="Output RGB image")
+            xyz_output = gr.Image(interactive=False, label="Output CCM image")
+
+            output_model = gr.Model3D(
+                label="Output GLB",
+                interactive=False,
+            )
+            output_obj = gr.File(interactive=False, label="Output OBJ")
+
+    inputs = [
+        processed_image,
+        seed,
+        guidance_scale,
+        step,
+    ]
+    outputs = [
+        image_output,
+        xyz_output,
+        output_model,
+        output_obj,
+    ]
+    gr.Examples(
+        examples=[os.path.join("examples", i) for i in os.listdir("examples")],
+        inputs=[image_input],
+    )
+
+    text_button.click(fn=check_input_image, inputs=[image_input]).success(
+        fn=preprocess_image,
+        inputs=[image_input, do_remove_background, force_remove, foreground_ratio, background_color],
+        outputs=[processed_image],
+    ).success(
+        fn=gen_image,
+        inputs=inputs,
+        outputs=outputs,
+    )
+    demo.queue().launch()
diff --git a/configs/nf7_v3_SNR_rd_size_stroke.yaml b/configs/nf7_v3_SNR_rd_size_stroke.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f9f81a412c773d0757819a4323384c33b540da96
--- /dev/null
+++ b/configs/nf7_v3_SNR_rd_size_stroke.yaml
@@ -0,0 +1,21 @@
+config:
+# others
+    seed: 1234
+    num_frames: 7
+    mode: pixel
+    offset_noise: true
+# model related
+    models:
+        config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
+        resume: models/pixel.pth
+# sampler related
+    sampler:
+        target: libs.sample.ImageDreamDiffusion
+        params:
+            mode: pixel
+            num_frames: 7
+            camera_views: [1, 2, 3, 4, 5, 0, 0]
+            ref_position: 6
+            random_background: false
+            offset_noise: true
+            resize_rate: 1.0
\ No newline at end of file
diff --git a/configs/specs_objaverse_total.json b/configs/specs_objaverse_total.json
new file mode 100644
index 0000000000000000000000000000000000000000..d1541e3c7ac3c93264065567eb11a11eddccf359
--- /dev/null
+++ b/configs/specs_objaverse_total.json
@@ -0,0 +1,57 @@
+{
+  "Input": {
+    "img_num": 16,
+    "class": "all",
+    "camera_angle_num": 8,
+    "tet_grid_size": 80,
+    "validate_num": 16,
+    "scale": 0.95,
+    "radius": 3,
+    "resolution": [256, 256]
+  },
+
+  "Pretrain": {
+    "mode": null,
+    "sdf_threshold": 0.1,
+    "sdf_scale": 10,
+    "batch_infer": false,
+    "lr": 1e-4,
+    "radius": 0.5
+  },
+
+  "Train": {
+    "mode": "rnd",
+    "num_epochs": 500,
+    "grad_acc": 1,
+    "warm_up": 0,
+    "decay": 0.000,
+    "learning_rate": {
+      "init": 1e-4,
+      "sdf_decay": 1,
+      "rgb_decay": 1
+    },
+    "batch_size": 4,
+    "eva_iter": 80,
+    "eva_all_epoch": 10,
+    "tex_sup_mode": "blender",
+    "exp_uv_mesh": false,
+    "doub": false,
+    "random_bg": false,
+    "shift": 0,
+    "aug_shift": 0,
+    "geo_type": "flex"
+  },
+
+  "ArchSpecs": {
+    "unet_type": "diffusers",
+    "use_3D_aware": false,
+    "fea_concat": false,
+    "mlp_bias": true
+  },
+
+  "DecoderSpecs": {
+    "c_dim": 32,
+    "plane_resolution": 256
+  }
+}
+
diff --git a/configs/stage2-v2-snr.yaml b/configs/stage2-v2-snr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6dafe186c51fe0aa70739a1bbe983e2a04f42107
--- /dev/null
+++ b/configs/stage2-v2-snr.yaml
@@ -0,0 +1,25 @@
+config:
+# others
+    seed: 1234
+    num_frames: 6
+    mode: pixel
+    offset_noise: true
+    gd_type: xyz
+# model related
+    models:
+        config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
+        resume: models/xyz.pth
+
+# eval related
+    sampler:
+        target: libs.sample.ImageDreamDiffusionStage2
+        params:
+            mode: pixel
+            num_frames: 6
+            camera_views: [1, 2, 3, 4, 5, 0]
+            ref_position: null
+            random_background: false
+            offset_noise: true
+            resize_rate: 1.0
+
+
diff --git a/imagedream/__init__.py b/imagedream/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba541c412090383da3a7e26113d00c6fe372daa
--- /dev/null
+++ b/imagedream/__init__.py
@@ -0,0 +1 @@
+from .model_zoo import build_model
diff --git a/imagedream/camera_utils.py b/imagedream/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a5d80755135f1c279ec49e0fd46c64962cde0f
--- /dev/null
+++ b/imagedream/camera_utils.py
@@ -0,0 +1,99 @@
+import numpy as np
+import torch
+
+
+def create_camera_to_world_matrix(elevation, azimuth):
+    elevation = np.radians(elevation)
+    azimuth = np.radians(azimuth)
+    # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
+    x = np.cos(elevation) * np.sin(azimuth)
+    y = np.sin(elevation)
+    z = np.cos(elevation) * np.cos(azimuth)
+
+    # Calculate camera position, target, and up vectors
+    camera_pos = np.array([x, y, z])
+    target = np.array([0, 0, 0])
+    up = np.array([0, 1, 0])
+
+    # Construct view matrix
+    forward = target - camera_pos
+    forward /= np.linalg.norm(forward)
+    right = np.cross(forward, up)
+    right /= np.linalg.norm(right)
+    new_up = np.cross(right, forward)
+    new_up /= np.linalg.norm(new_up)
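+    # assemble the camera-to-world matrix; the rotation columns are right, up and -forward (camera looks along -z, OpenGL convention)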
+    cam2world = np.eye(4)
+    cam2world[:3, :3] = np.array([right, new_up, -forward]).T
+    cam2world[:3, 3] = camera_pos
+    return cam2world
+
+
+def convert_opengl_to_blender(camera_matrix):
+    if isinstance(camera_matrix, np.ndarray):
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+        camera_matrix_blender = np.dot(flip_yz, camera_matrix)
+    else:
+        # Construct transformation matrix to convert from OpenGL space to Blender space
+        flip_yz = torch.tensor(
+            [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
+        )
+        if camera_matrix.ndim == 3:
+            flip_yz = flip_yz.unsqueeze(0)
+        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
+    return camera_matrix_blender
+
+
+def normalize_camera(camera_matrix):
+    """normalize the camera location onto a unit-sphere"""
+    if isinstance(camera_matrix, np.ndarray):
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    else:
+        camera_matrix = camera_matrix.reshape(-1, 4, 4)
+        translation = camera_matrix[:, :3, 3]
+        translation = translation / (
+            torch.norm(translation, dim=1, keepdim=True) + 1e-8
+        )
+        camera_matrix[:, :3, 3] = translation
+    return camera_matrix.reshape(-1, 16)
+
+
+def get_camera(
+    num_frames, 
+    elevation=15, 
+    azimuth_start=0, 
+    azimuth_span=360, 
+    blender_coord=True,
+    extra_view=False,
+):
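+    # place num_frames cameras at equal azimuth steps over azimuth_span degrees, all at the same elevation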
+    angle_gap = azimuth_span / num_frames
+    cameras = []
+    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
+        camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
+        if blender_coord:
+            camera_matrix = convert_opengl_to_blender(camera_matrix)
+        cameras.append(camera_matrix.flatten())
+        
+    if extra_view:
+        dim = len(cameras[0])
+        cameras.append(np.zeros(dim))  
+    return torch.tensor(np.stack(cameras, 0)).float()
+
+
+def get_camera_for_index(data_index):
+    """
+    Following the current data format, with view 000 facing the viewer:
+    000 is the front view,  elevation: 0,   azimuth: 0
+    001 is the left view,   elevation: 0,   azimuth: -90
+    002 is the bottom view, elevation: -90, azimuth: 0
+    003 is the back view,   elevation: 0,   azimuth: 180
+    004 is the right view,  elevation: 0,   azimuth: 90
+    005 is the top view,    elevation: 90,  azimuth: 0
+    """
+    params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)]
+    return get_camera(1, *params[data_index])
\ No newline at end of file
diff --git a/imagedream/configs/sd_v2_base_ipmv.yaml b/imagedream/configs/sd_v2_base_ipmv.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..234fa01d97eba0fbc1c700d920d4dd4947d9b309
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv.yaml
@@ -0,0 +1,61 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/configs/sd_v2_base_ipmv_ch8.yaml b/imagedream/configs/sd_v2_base_ipmv_ch8.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1979885bb2a22db2002cfe52345ffcb29f64addc
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv_ch8.yaml
@@ -0,0 +1,61 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 8
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/configs/sd_v2_base_ipmv_chin8.yaml b/imagedream/configs/sd_v2_base_ipmv_chin8.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6010ded905a3233d0474491c374f45ff6649503
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv_chin8.yaml
@@ -0,0 +1,61 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml b/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fa9e1c8fc87dfef6c06ef2db3160d828ab95ae2f
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
@@ -0,0 +1,62 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+    zero_snr: true
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
+      params:
+        image_size: 32 # unused
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/configs/sd_v2_base_ipmv_local.yaml b/imagedream/configs/sd_v2_base_ipmv_local.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f50e72870272346632d91849723a49221317ee30
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv_local.yaml
@@ -0,0 +1,62 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+        ip_weight: 1.0 # adjust for similarity to image 
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml b/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d0115e086043c4dca06dfbf1a0fd03b4f4b6bec8
--- /dev/null
+++ b/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
@@ -0,0 +1,62 @@
+model:
+  target: imagedream.ldm.interface.LatentDiffusionInterface
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    timesteps: 1000
+    scale_factor: 0.18215
+    parameterization: "eps"
+    zero_snr: true
+
+    unet_config:
+      target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        use_checkpoint: False
+        legacy: False
+        camera_dim: 16
+        with_ip: True
+        ip_dim: 16 # ip token length
+        ip_mode: "local_resample"
+
+    vae_config:
+      target: imagedream.ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    clip_config:
+      target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"
+        ip_mode: "local_resample"
diff --git a/imagedream/ldm/__init__.py b/imagedream/ldm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/interface.py b/imagedream/ldm/interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d417c895489ba514008fd11bfc18571897113a
--- /dev/null
+++ b/imagedream/ldm/interface.py
@@ -0,0 +1,205 @@
+from typing import List
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .modules.diffusionmodules.util import (
+    make_beta_schedule,
+    extract_into_tensor,
+    enforce_zero_terminal_snr,
+    noise_like,
+)
+from .util import exists, default, instantiate_from_config
+from .modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+class DiffusionWrapper(nn.Module):
+    def __init__(self, diffusion_model):
+        super().__init__()
+        self.diffusion_model = diffusion_model
+
+    def forward(self, *args, **kwargs):
+        return self.diffusion_model(*args, **kwargs)
+
+
+class LatentDiffusionInterface(nn.Module):
+    """a simple interface class for LDM inference"""
+
+    def __init__(
+        self,
+        unet_config,
+        clip_config,
+        vae_config,
+        parameterization="eps",
+        scale_factor=0.18215,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=0.00085,
+        linear_end=0.0120,
+        cosine_s=8e-3,
+        given_betas=None,
+        zero_snr=False,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+
+        unet = instantiate_from_config(unet_config)
+        self.model = DiffusionWrapper(unet)
+        self.clip_model = instantiate_from_config(clip_config)
+        self.vae_model = instantiate_from_config(vae_config)
+
+        self.parameterization = parameterization
+        self.scale_factor = scale_factor
+        self.register_schedule(
+            given_betas=given_betas,
+            beta_schedule=beta_schedule,
+            timesteps=timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+            zero_snr=zero_snr
+        )
+
+    def register_schedule(
+        self,
+        given_betas=None,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+        zero_snr=False
+    ):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(
+                beta_schedule,
+                timesteps,
+                linear_start=linear_start,
+                linear_end=linear_end,
+                cosine_s=cosine_s,
+            )
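+        # optionally rescale the betas to enforce a zero terminal SNR, so the final timestep carries no signal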
+        if zero_snr:
+            print("--- using zero snr---")
+            betas = enforce_zero_terminal_snr(betas).numpy()
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        (timesteps,) = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        self.v_posterior = 0
+        posterior_variance = (1 - self.v_posterior) * betas * (
+            1.0 - alphas_cumprod_prev
+        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer("posterior_variance", to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer(
+            "posterior_log_variance_clipped",
+            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+        )
+        self.register_buffer(
+            "posterior_mean_coef1",
+            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+        )
+        self.register_buffer(
+            "posterior_mean_coef2",
+            to_torch(
+                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+            ),
+        )
+
+    def q_sample(self, x_start, t, noise=None):
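+        # forward diffusion q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise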
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def get_v(self, x, noise, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+            * noise
+        )
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+            * x_t
+        )
+
+    def apply_model(self, x_noisy, t, cond, **kwargs):
+        assert isinstance(cond, dict), "cond has to be a dictionary"
+        return self.model(x_noisy, t, **cond, **kwargs)
+
+    def get_learned_conditioning(self, prompts: List[str]):
+        return self.clip_model(prompts)
+    
+    def get_learned_image_conditioning(self, images):
+        return self.clip_model.forward_image(images)
+
+    def get_first_stage_encoding(self, encoder_posterior):
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(
+                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+            )
+        return self.scale_factor * z
+
+    def encode_first_stage(self, x):
+        return self.vae_model.encode(x)
+
+    def decode_first_stage(self, z):
+        z = 1.0 / self.scale_factor * z
+        return self.vae_model.decode(z)
diff --git a/imagedream/ldm/models/__init__.py b/imagedream/ldm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/models/autoencoder.py b/imagedream/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8941041399e08f538c5e32f523eba54889c5bd4c
--- /dev/null
+++ b/imagedream/ldm/models/autoencoder.py
@@ -0,0 +1,270 @@
+import torch
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ..modules.diffusionmodules.model import Encoder, Decoder
+from ..modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ..util import instantiate_from_config
+from ..modules.ema import LitEma
+
+
+class AutoencoderKL(torch.nn.Module):
+    def __init__(
+        self,
+        ddconfig,
+        lossconfig,
+        embed_dim,
+        ckpt_path=None,
+        ignore_keys=[],
+        image_key="image",
+        colorize_nlabels=None,
+        monitor=None,
+        ema_decay=None,
+        learn_logvar=False,
+    ):
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels) == int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0.0 < ema_decay < 1.0
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(
+                inputs,
+                reconstructions,
+                posterior,
+                optimizer_idx,
+                self.global_step,
+                last_layer=self.get_last_layer(),
+                split="train",
+            )
+            self.log(
+                "aeloss",
+                aeloss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(
+                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
+            )
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(
+                inputs,
+                reconstructions,
+                posterior,
+                optimizer_idx,
+                self.global_step,
+                last_layer=self.get_last_layer(),
+                split="train",
+            )
+
+            self.log(
+                "discloss",
+                discloss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(
+                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
+            )
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(
+            inputs,
+            reconstructions,
+            posterior,
+            0,
+            self.global_step,
+            last_layer=self.get_last_layer(),
+            split="val" + postfix,
+        )
+
+        discloss, log_dict_disc = self.loss(
+            inputs,
+            reconstructions,
+            posterior,
+            1,
+            self.global_step,
+            last_layer=self.get_last_layer(),
+            split="val" + postfix,
+        )
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        ae_params_list = (
+            list(self.encoder.parameters())
+            + list(self.decoder.parameters())
+            + list(self.quant_conv.parameters())
+            + list(self.post_quant_conv.parameters())
+        )
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(
+            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
+        )
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(
+                        torch.randn_like(posterior_ema.sample())
+                    )
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
diff --git a/imagedream/ldm/models/diffusion/__init__.py b/imagedream/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/models/diffusion/ddim.py b/imagedream/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5e1372f2ec75b16d3c35ca7fa7111321f8a4ade
--- /dev/null
+++ b/imagedream/ldm/models/diffusion/ddim.py
@@ -0,0 +1,430 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ...modules.diffusionmodules.util import (
+    make_ddim_sampling_parameters,
+    make_ddim_timesteps,
+    noise_like,
+    extract_into_tensor,
+)
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
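+        # not a true nn.Module buffer: tensors are moved to CUDA and stored as plain attributes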
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(
+        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+    ):
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=self.ddpm_num_timesteps,
+            verbose=verbose,
+        )
+        alphas_cumprod = self.model.alphas_cumprod
+        assert (
+            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+        ), "alphas have to be defined for each timestep"
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer("betas", to_torch(self.model.betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer(
+            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+        )
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer(
+            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod",
+            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod",
+            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+        )
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+            alphacums=alphas_cumprod.cpu(),
+            ddim_timesteps=self.ddim_timesteps,
+            eta=ddim_eta,
+            verbose=verbose,
+        )
+        self.register_buffer("ddim_sigmas", ddim_sigmas)
+        self.register_buffer("ddim_alphas", ddim_alphas)
+        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev)
+            / (1 - self.alphas_cumprod)
+            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+        )
+        self.register_buffer(
+            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+        )
+
+    @torch.no_grad()
+    def sample(
+        self,
+        S,
+        batch_size,
+        shape,
+        conditioning=None,
+        callback=None,
+        normals_sequence=None,
+        img_callback=None,
+        quantize_x0=False,
+        eta=0.0,
+        mask=None,
+        x0=None,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        verbose=True,
+        x_T=None,
+        log_every_t=100,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+        **kwargs,
+    ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(
+                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+                    )
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(
+                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+                    )
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        samples, intermediates = self.ddim_sampling(
+            conditioning,
+            size,
+            callback=callback,
+            img_callback=img_callback,
+            quantize_denoised=quantize_x0,
+            mask=mask,
+            x0=x0,
+            ddim_use_original_steps=False,
+            noise_dropout=noise_dropout,
+            temperature=temperature,
+            score_corrector=score_corrector,
+            corrector_kwargs=corrector_kwargs,
+            x_T=x_T,
+            log_every_t=log_every_t,
+            unconditional_guidance_scale=unconditional_guidance_scale,
+            unconditional_conditioning=unconditional_conditioning,
+            **kwargs,
+        )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(
+        self,
+        cond,
+        shape,
+        x_T=None,
+        ddim_use_original_steps=False,
+        callback=None,
+        timesteps=None,
+        quantize_denoised=False,
+        mask=None,
+        x0=None,
+        img_callback=None,
+        log_every_t=100,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        **kwargs,
+    ):
+        """
+        Typical parameter values at inference time:
+        cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
+        shape: (5, 4, 32, 32)
+        x_T: None
+        ddim_use_original_steps: False
+        timesteps: None
+        callback: None
+        quantize_denoised: False
+        mask: None
+        img_callback: None
+        log_every_t: 100
+        temperature: 1.0
+        noise_dropout: 0.0
+        score_corrector: None
+        corrector_kwargs: None
+        unconditional_guidance_scale: 5
+        unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
+        kwargs: {}
+        """
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94
+        else:
+            img = x_T
+
+        if timesteps is None:  # equivalent to setting the sampling timesteps in HF diffusers
+            timesteps = (
+                self.ddpm_num_timesteps
+                if ddim_use_original_steps
+                else self.ddim_timesteps
+            )
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = (
+                int(
+                    min(timesteps / self.ddim_timesteps.shape[0], 1)
+                    * self.ddim_timesteps.shape[0]
+                )
+                - 1
+            )
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {"x_inter": [img], "pred_x0": [img]}
+        time_range = ( # reversed timesteps
+            reversed(range(0, timesteps))
+            if ddim_use_original_steps
+            else np.flip(timesteps)
+        )
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(
+                    x0, ts
+                )  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1.0 - mask) * img
+
+            outs = self.p_sample_ddim(
+                img,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=ddim_use_original_steps,
+                quantize_denoised=quantize_denoised,
+                temperature=temperature,
+                noise_dropout=noise_dropout,
+                score_corrector=score_corrector,
+                corrector_kwargs=corrector_kwargs,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                **kwargs,
+            )
+            img, pred_x0 = outs
+            if callback:
+                callback(i)
+            if img_callback:
+                img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates["x_inter"].append(img)
+                intermediates["pred_x0"].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(
+        self,
+        x,
+        c,
+        t,
+        index,
+        repeat_noise=False,
+        use_original_steps=False,
+        quantize_denoised=False,
+        temperature=1.0,
+        noise_dropout=0.0,
+        score_corrector=None,
+        corrector_kwargs=None,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        dynamic_threshold=None,
+        **kwargs,
+    ):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+            model_output = self.model.apply_model(x, t, c)
+        else:
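+            # Classifier-free guidance: run the unconditional and conditional
+            # branches in one batched forward pass, then combine the outputs as
+            #   e = e_uncond + guidance_scale * (e_cond - e_uncond)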
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [
+                            torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                            for i in range(len(c[k]))
+                        ]
+                    elif isinstance(c[k], torch.Tensor):
+                        c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+                    else:
+                        assert c[k] == unconditional_conditioning[k]
+                        c_in[k] = c[k]
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (
+                model_t - model_uncond
+            )
+
+        if self.model.parameterization == "v":
+            print("using v!")
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", "not implemented"
+            e_t = score_corrector.modify_score(
+                self.model, e_t, x, t, c, **corrector_kwargs
+            )
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = (
+            self.model.alphas_cumprod_prev
+            if use_original_steps
+            else self.ddim_alphas_prev
+        )
+        sqrt_one_minus_alphas = (
+            self.model.sqrt_one_minus_alphas_cumprod
+            if use_original_steps
+            else self.ddim_sqrt_one_minus_alphas
+        )
+        sigmas = (
+            self.model.ddim_sigmas_for_original_num_steps
+            if use_original_steps
+            else self.ddim_sigmas
+        )
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full(
+            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+        )
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.0:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
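+        # Closed-form forward diffusion:
+        #   x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise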
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        use_original_steps=False,
+        **kwargs,
+    ):
+        timesteps = (
+            np.arange(self.ddpm_num_timesteps)
+            if use_original_steps
+            else self.ddim_timesteps
+        )
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+
+        iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            x_dec, _ = self.p_sample_ddim(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                use_original_steps=use_original_steps,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                **kwargs,
+            )
+        return x_dec
diff --git a/imagedream/ldm/modules/__init__.py b/imagedream/ldm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/modules/attention.py b/imagedream/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0693788a7dc44170aaee8bc90c13c4dfee13e4f9
--- /dev/null
+++ b/imagedream/ldm/modules/attention.py
@@ -0,0 +1,456 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from .diffusionmodules.util import checkpoint
+
+
+try:
+    import xformers
+    import xformers.ops
+
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = (
+            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+            if not glu
+            else GEGLU(dim, inner_dim)
+        )
+
+        self.net = nn.Sequential(
+            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(
+        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = rearrange(q, "b c h w -> b (h w) c")
+        k = rearrange(k, "b c h w -> b c (h w)")
+        w_ = torch.einsum("bij,bjk->bik", q, k)
+
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, "b c h w -> b c (h w)")
+        w_ = rearrange(w_, "b i j -> b j i")
+        h_ = torch.einsum("bij,bjk->bik", v, w_)
+        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
+        super().__init__()
+        print(
+            f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+            f"{heads} heads."
+        )
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+        
+        self.with_ip = kwargs.get("with_ip", False)
+        if self.with_ip and (context_dim is not None):
+            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+            self.ip_dim = kwargs.get("ip_dim", 16)
+            self.ip_weight = kwargs.get("ip_weight", 1.0)
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+        )
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        
+        has_ip = self.with_ip and (context is not None)
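+        # Decoupled image-prompt cross-attention (cf. IP-Adapter): the last
+        # ip_dim context tokens are image tokens with their own key/value
+        # projections (to_k_ip / to_v_ip); their attention output is added to
+        # the text attention output below, scaled by ip_weight.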
+        if has_ip:
+            # context shape: [(b * num_frames), 77 + ip_dim, 1024]
+            token_len = context.shape[1]
+            context_ip = context[:, -self.ip_dim:, :]
+            k_ip = self.to_k_ip(context_ip)
+            v_ip = self.to_v_ip(context_ip)
+            context = context[:, :(token_len - self.ip_dim), :]
+
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(
+            q, k, v, attn_bias=None, op=self.attention_op
+        )
+        
+        if has_ip:
+            k_ip, v_ip = map(
+                lambda t: t.unsqueeze(3)
+                .reshape(b, t.shape[1], self.heads, self.dim_head)
+                .permute(0, 2, 1, 3)
+                .reshape(b * self.heads, t.shape[1], self.dim_head)
+                .contiguous(),
+                (k_ip, v_ip),
+            )
+            # actually compute the attention, what we cannot get enough of
+            out_ip = xformers.ops.memory_efficient_attention(
+                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
+            )
+            out = out + self.ip_weight * out_ip
+            
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        n_heads,
+        d_head,
+        dropout=0.0,
+        context_dim=None,
+        gated_ff=True,
+        checkpoint=True,
+        disable_self_attn=False,
+        **kwargs
+    ):
+        super().__init__()
+        assert XFORMERS_IS_AVAILBLE, "xformers is not available"
+        attn_cls = MemoryEfficientCrossAttention
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(
+            query_dim=dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+            context_dim=context_dim if self.disable_self_attn else None,
+        )  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(
+            query_dim=dim,
+            context_dim=context_dim,
+            heads=n_heads,
+            dim_head=d_head,
+            dropout=dropout,
+            **kwargs
+        )  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
+    def forward(self, x, context=None):
+        return checkpoint(
+            self._forward, (x, context), self.parameters(), self.checkpoint
+        )
+
+    def _forward(self, x, context=None):
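+        # Pre-norm transformer block: attn1 (self-attention unless
+        # disable_self_attn), attn2 (cross-attention to the context), and a
+        # feed-forward layer, each wrapped in a residual connection.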
+        x = (
+            self.attn1(
+                self.norm1(x), context=context if self.disable_self_attn else None
+            )
+            + x
+        )
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding) and reshape it to (b, t, d).
+    Then apply the standard transformer action.
+    Finally, reshape back to an image.
+    NEW: use_linear swaps the 1x1 convs for linear layers, for efficiency.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        context_dim=None,
+        disable_self_attn=False,
+        use_linear=False,
+        use_checkpoint=True,
+        **kwargs
+    ):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(
+                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+            )
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    n_heads,
+                    d_head,
+                    dropout=dropout,
+                    context_dim=context_dim[d],
+                    disable_self_attn=disable_self_attn,
+                    checkpoint=use_checkpoint,
+                    **kwargs
+                )
+                for d in range(depth)
+            ]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(
+                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+            )
+        else:
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
+
+
+class BasicTransformerBlock3D(BasicTransformerBlock):
+    def forward(self, x, context=None, num_frames=1):
+        return checkpoint(
+            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
+        )
+
+    def _forward(self, x, context=None, num_frames=1):
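+        # Multi-view self-attention: fold the frame axis into the token axis so
+        # attn1 attends across all num_frames views jointly, then unfold again
+        # so cross-attention and the feed-forward run per frame.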
+        x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+        x = (
+            self.attn1(
+                self.norm1(x), 
+                context=context if self.disable_self_attn else None
+            )
+            + x
+        )
+        x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer3D(nn.Module):
+    """3D self-attention"""
+
+    def __init__(
+        self,
+        in_channels,
+        n_heads,
+        d_head,
+        depth=1,
+        dropout=0.0,
+        context_dim=None,
+        disable_self_attn=False,
+        use_linear=False,
+        use_checkpoint=True,
+        **kwargs
+    ):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(
+                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+            )
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock3D(
+                    inner_dim,
+                    n_heads,
+                    d_head,
+                    dropout=dropout,
+                    context_dim=context_dim[d],
+                    disable_self_attn=disable_self_attn,
+                    checkpoint=use_checkpoint,
+                    **kwargs
+                )
+                for d in range(depth)
+            ]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(
+                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+            )
+        else:
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None, num_frames=1):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i], num_frames=num_frames)
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
diff --git a/imagedream/ldm/modules/diffusionmodules/__init__.py b/imagedream/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/modules/diffusionmodules/adaptors.py b/imagedream/ldm/modules/diffusionmodules/adaptors.py
new file mode 100644
index 0000000000000000000000000000000000000000..547e7f4173c85f7f62ff76e38e9794ee2f31660e
--- /dev/null
+++ b/imagedream/ldm/modules/diffusionmodules/adaptors.py
@@ -0,0 +1,163 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+import math
+
+import torch
+import torch.nn as nn
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+    
+    
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # shape stays (bs, n_heads, length, dim_per_head)
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+        
+        b, l, _ = latents.shape
+
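+        # Perceiver-style cross-attention: the learned latents provide the
+        # queries, while keys/values are computed from the image features
+        # concatenated with the latents themselves.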
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+        
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+    
+    
+class ImageProjModel(torch.nn.Module):
+    """Projection Model"""
+    def __init__(self, 
+            cross_attention_dim=1024, 
+            clip_embeddings_dim=1024, 
+            clip_extra_context_tokens=4):
+        super().__init__()
+        self.cross_attention_dim = cross_attention_dim
+        self.clip_extra_context_tokens = clip_extra_context_tokens
+
+        # from 1024 -> 4 * 1024
+        self.proj = torch.nn.Linear(
+            clip_embeddings_dim, 
+            self.clip_extra_context_tokens * cross_attention_dim)
+        self.norm = torch.nn.LayerNorm(cross_attention_dim)
+        
+    def forward(self, image_embeds):
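+        # Project a global CLIP image embedding into clip_extra_context_tokens
+        # context tokens of width cross_attention_dim, then LayerNorm them.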
+        embeds = image_embeds
+        clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
+        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+        return clip_extra_context_tokens
+    
+
+class SimpleReSampler(nn.Module):
+    def __init__(self, embedding_dim=1280, output_dim=1024):
+        super().__init__()
+        self.proj_out = nn.Linear(embedding_dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+    def forward(self, latents):
+        """
+            latents: B 256 N
+        """
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
+
+
+class Resampler(nn.Module):
+    def __init__(
+        self,
+        dim=1024,
+        depth=8,
+        dim_head=64,
+        heads=16,
+        num_queries=8,
+        embedding_dim=768,
+        output_dim=1024,
+        ff_mult=4,
+    ):
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+        self.proj_in = nn.Linear(embedding_dim, dim)
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, 
+                                           dim_head=dim_head, 
+                                           heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+    def forward(self, x):
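+        # Perceiver resampler: a fixed set of learned query tokens repeatedly
+        # attends to the projected image tokens, then is mapped to output_dim.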
+        latents = self.latents.repeat(x.size(0), 1, 1)
+        x = self.proj_in(x)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+            
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
+
+
+if __name__ == '__main__':
+    # quick shape check for both adaptor variants
+    tensor = torch.rand(4, 257, 1280)
+    resampler = Resampler(embedding_dim=1280)
+    print(resampler(tensor).shape)
+    simple_resampler = SimpleReSampler(embedding_dim=1280)
+    print(simple_resampler(tensor).shape)
diff --git a/imagedream/ldm/modules/diffusionmodules/model.py b/imagedream/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..090df6b33cccee7950a1d5719ec7513cc61d588f
--- /dev/null
+++ b/imagedream/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,1018 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ..attention import MemoryEfficientCrossAttention
+
+try:
+    import xformers
+    import xformers.ops
+
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings, as in Denoising Diffusion
+    Probabilistic Models (ported from Fairseq). This matches the
+    implementation in tensor2tensor, but differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+    """
+    assert len(timesteps.shape) == 1
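+    # Frequencies follow exp(-log(10000) * i / (half_dim - 1)) for i in
+    # [0, half_dim); the embedding concatenates sin(t * freq) and cos(t * freq).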
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(
+        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=2, padding=0
+            )
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout,
+        temb_channels=512,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)  # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+    """
+    Uses xformers efficient implementation,
+    see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    Note: this is a single-head self-attention operation
+    """
+
+    #
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(
+            q, k, v, attn_bias=None, op=self.attention_op
+        )
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x + out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x = rearrange(x, "b c h w -> b (h w) c")
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+        return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in [
+        "vanilla",
+        "vanilla-xformers",
+        "memory-efficient-cross-attn",
+        "linear",
+        "none",
+    ], f"attn_type {attn_type} unknown"
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        assert attn_kwargs is None
+        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        raise NotImplementedError()
+
+
+class Model(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        use_timestep=True,
+        use_linear_attn=False,
+        attn_type="vanilla",
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = self.ch * 4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList(
+                [
+                    torch.nn.Linear(self.ch, self.temb_ch),
+                    torch.nn.Linear(self.temb_ch, self.temb_ch),
+                ]
+            )
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            skip_in = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch * in_ch_mult[i_level]
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in + skip_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x, t=None, context=None):
+        # assert x.shape[2] == x.shape[3] == self.resolution
+        if context is not None:
+            # assume aligned context, cat along channel axis
+            x = torch.cat((x, context), dim=1)
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb
+                )
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        double_z=True,
+        use_linear_attn=False,
+        attn_type="vanilla",
+        **ignore_kwargs,
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            2 * z_channels if double_z else z_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, x):
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch,
+        out_ch,
+        ch_mult=(1, 2, 4, 8),
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels,
+        resolution,
+        z_channels,
+        give_pre_end=False,
+        tanh_out=False,
+        use_linear_attn=False,
+        attn_type="vanilla",
+        **ignorekwargs,
+    ):
+        super().__init__()
+        if use_linear_attn:
+            attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
+        )
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, z):
+        # assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        if self.tanh_out:
+            h = torch.tanh(h)
+        return h
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList(
+            [
+                nn.Conv2d(in_channels, in_channels, 1),
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=2 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                ResnetBlock(
+                    in_channels=2 * in_channels,
+                    out_channels=4 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                ResnetBlock(
+                    in_channels=4 * in_channels,
+                    out_channels=2 * in_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                ),
+                nn.Conv2d(2 * in_channels, in_channels, 1),
+                Upsample(in_channels, with_conv=True),
+            ]
+        )
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1, 2, 3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        ch,
+        num_res_blocks,
+        resolution,
+        ch_mult=(2, 2),
+        dropout=0.0,
+    ):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_channels, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for k, i_level in enumerate(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[k](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class LatentRescaler(nn.Module):
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+        self.conv_in = nn.Conv2d(
+            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+        )
+        self.res_block1 = nn.ModuleList(
+            [
+                ResnetBlock(
+                    in_channels=mid_channels,
+                    out_channels=mid_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                )
+                for _ in range(depth)
+            ]
+        )
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = nn.ModuleList(
+            [
+                ResnetBlock(
+                    in_channels=mid_channels,
+                    out_channels=mid_channels,
+                    temb_channels=0,
+                    dropout=0.0,
+                )
+                for _ in range(depth)
+            ]
+        )
+
+        self.conv_out = nn.Conv2d(
+            mid_channels,
+            out_channels,
+            kernel_size=1,
+        )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(
+            x,
+            size=(
+                int(round(x.shape[2] * self.factor)),
+                int(round(x.shape[3] * self.factor)),
+            ),
+        )
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
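+
+# A minimal usage sketch (assumes the standard ResnetBlock/AttnBlock defined above):
+# rescale a 4-channel latent by 1.5x while projecting it to 8 channels.
+#   rescaler = LatentRescaler(factor=1.5, in_channels=4, mid_channels=64, out_channels=8)
+#   y = rescaler(torch.randn(1, 4, 32, 32))  # -> (1, 8, 48, 48)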
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        ch,
+        resolution,
+        out_ch,
+        num_res_blocks,
+        attn_resolutions,
+        dropout=0.0,
+        resamp_with_conv=True,
+        ch_mult=(1, 2, 4, 8),
+        rescale_factor=1.0,
+        rescale_module_depth=1,
+    ):
+        super().__init__()
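+        # Pipeline: a full Encoder that downsamples to a latent of ch * ch_mult[-1] channels,
+        # followed by a LatentRescaler that resizes it by `rescale_factor` and projects to `out_ch`.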
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            num_res_blocks=num_res_blocks,
+            ch=ch,
+            ch_mult=ch_mult,
+            z_channels=intermediate_chn,
+            double_z=False,
+            resolution=resolution,
+            attn_resolutions=attn_resolutions,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            out_ch=None,
+        )
+        self.rescaler = LatentRescaler(
+            factor=rescale_factor,
+            in_channels=intermediate_chn,
+            mid_channels=intermediate_chn,
+            out_channels=out_ch,
+            depth=rescale_module_depth,
+        )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(
+        self,
+        z_channels,
+        out_ch,
+        resolution,
+        num_res_blocks,
+        attn_resolutions,
+        ch,
+        ch_mult=(1, 2, 4, 8),
+        dropout=0.0,
+        resamp_with_conv=True,
+        rescale_factor=1.0,
+        rescale_module_depth=1,
+    ):
+        super().__init__()
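+        # Pipeline: a LatentRescaler first resizes the z_channels latent by `rescale_factor`,
+        # then the Decoder upsamples it back to `resolution` with `out_ch` output channels.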
+        tmp_chn = z_channels * ch_mult[-1]
+        self.decoder = Decoder(
+            out_ch=out_ch,
+            z_channels=tmp_chn,
+            attn_resolutions=attn_resolutions,
+            dropout=dropout,
+            resamp_with_conv=resamp_with_conv,
+            in_channels=None,
+            num_res_blocks=num_res_blocks,
+            ch_mult=ch_mult,
+            resolution=resolution,
+            ch=ch,
+        )
+        self.rescaler = LatentRescaler(
+            factor=rescale_factor,
+            in_channels=z_channels,
+            mid_channels=tmp_chn,
+            out_channels=tmp_chn,
+            depth=rescale_module_depth,
+        )
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size // in_size)) + 1
+        factor_up = 1.0 + (out_size % in_size)
+        print(
+            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+        )
+        self.rescaler = LatentRescaler(
+            factor=factor_up,
+            in_channels=in_channels,
+            mid_channels=2 * in_channels,
+            out_channels=in_channels,
+        )
+        self.decoder = Decoder(
+            out_ch=out_channels,
+            resolution=out_size,
+            z_channels=in_channels,
+            num_res_blocks=2,
+            attn_resolutions=[],
+            in_channels=None,
+            ch=in_channels,
+            ch_mult=[ch_mult for _ in range(num_blocks)],
+        )
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
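+
+# A minimal usage sketch (assumes the standard ldm Decoder defined above; sizes are
+# illustrative only): upsample a 64x64, 4-channel latent to a 128x128, 3-channel output.
+#   up = Upsampler(in_size=64, out_size=128, in_channels=4, out_channels=3)
+#   y = up(torch.randn(1, 4, 64, 64))  # -> (1, 3, 128, 128)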
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(
+                f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
+            )
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=4, stride=2, padding=1
+            )
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor == 1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(
+                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+            )
+        return x
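+
+# A minimal usage sketch: fixed (non-learned) bilinear resizing by an arbitrary factor.
+#   y = Resize(mode="bilinear")(torch.randn(1, 3, 64, 64), scale_factor=0.5)  # -> (1, 3, 32, 32)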
diff --git a/imagedream/ldm/modules/diffusionmodules/openaimodel.py b/imagedream/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d28ecbc9718438a8c6b720d78b2bf52f641d4cc
--- /dev/null
+++ b/imagedream/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1135 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from imagedream.ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+    convert_module_to_f16, 
+    convert_module_to_f32
+)
+from imagedream.ldm.modules.attention import (
+    SpatialTransformer, 
+    SpatialTransformer3D, 
+    exists
+)
+from imagedream.ldm.modules.diffusionmodules.adaptors import (
+    Resampler, 
+    ImageProjModel
+)
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+        )
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb, context=None, num_frames=1):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer3D):
+                x = layer(x, context, num_frames=num_frames)
+            elif isinstance(layer, SpatialTransformer):
+                x = layer(x, context)
+            else:
+                x = layer(x)
+        return x
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(
+                dims, self.channels, self.out_channels, 3, padding=padding
+            )
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+
+class TransposedUpsample(nn.Module):
+    "Learned 2x upsampling without padding"
+
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.up = nn.ConvTranspose2d(
+            self.channels, self.out_channels, kernel_size=ks, stride=2
+        )
+
+    def forward(self, x):
+        return self.up(x)
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims,
+                self.channels,
+                self.out_channels,
+                3,
+                stride=stride,
+                padding=padding,
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
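+            # FiLM-style "scale-shift" conditioning: the timestep embedding predicts a
+            # per-channel (scale, shift) pair that modulates the features right after GroupNorm.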
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(
+            self._forward, (x,), self.parameters(), True
+        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        # return pt_checkpoint(self._forward, x)  # pytorch
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial**2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
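+        # Scaling q and k each by 1/ch**0.25 yields the usual 1/sqrt(ch) attention scaling
+        # while keeping the intermediate products small enough for fp16.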
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
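+
+# A minimal shape sketch (hypothetical sizes): 2 heads with 8 channels per head over 16 tokens.
+#   attn = QKVAttention(n_heads=2)
+#   out = attn(th.randn(4, 3 * 2 * 8, 16))  # -> (4, 16, 16), i.e. [N x (H*C) x T]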
+
+
+class Timestep(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, t):
+        return timestep_embedding(t, self.dim)
+
+
+class MultiViewUNetModel(nn.Module):
+    """
+    The full multi-view UNet model with attention, timestep embedding and camera embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    :param camera_dim: dimensionality of camera input.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        use_bf16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,  # custom transformer support
+        transformer_depth=1,  # custom transformer support
+        context_dim=None,  # custom transformer support
+        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+        adm_in_channels=None,
+        camera_dim=None,
+        with_ip=False,  # whether to add image-prompt tokens
+        ip_dim=0,  # number of extra tokens: 4 for global, 16 for local
+        ip_weight=1.0,  # weight for image prompt context
+        ip_mode="local_resample",  # adaptor mode: "global" or "local_resample"
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert (
+                context_dim is not None
+            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+        if context_dim is not None:
+            assert (
+                use_spatial_transformer
+            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+            from omegaconf.listconfig import ListConfig
+
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert (
+                num_head_channels != -1
+            ), "Either num_heads or num_head_channels has to be set"
+
+        if num_head_channels == -1:
+            assert (
+                num_heads != -1
+            ), "Either num_heads or num_head_channels has to be set"
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError(
+                    "provide num_res_blocks either as an int (globally constant) or "
+                    "as a list/tuple (per-level) with the same length as channel_mult"
+                )
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(
+                map(
+                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+                    range(len(num_attention_blocks)),
+                )
+            )
+            print(
+                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                f"attention will still not be set."
+            )
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = th.bfloat16 if use_bf16 else self.dtype
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        self.with_ip = with_ip  # whether an image prompt is used
+        self.ip_dim = ip_dim  # number of extra tokens: 4 for global, 16 for local
+        self.ip_weight = ip_weight
+        assert ip_mode in ["global", "local_resample"]
+        self.ip_mode = ip_mode  # adaptor mode
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if camera_dim is not None:
+            time_embed_dim = model_channels * 4
+            self.camera_embed = nn.Sequential(
+                linear(camera_dim, time_embed_dim),
+                nn.SiLU(),
+                linear(time_embed_dim, time_embed_dim),
+            )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            elif self.num_classes == "sequential":
+                assert adm_in_channels is not None
+                self.label_emb = nn.Sequential(
+                    nn.Sequential(
+                        linear(adm_in_channels, time_embed_dim),
+                        nn.SiLU(),
+                        linear(time_embed_dim, time_embed_dim),
+                    )
+                )
+            else:
+                raise ValueError()
+
+        if self.with_ip and (context_dim is not None) and ip_dim > 0:
+            if self.ip_mode == "local_resample":
+                # ip-adapter-plus
+                hidden_dim = 1280
+                self.image_embed = Resampler(
+                    dim=context_dim,
+                    depth=4,
+                    dim_head=64,
+                    heads=12,
+                    num_queries=ip_dim,  # num token
+                    embedding_dim=hidden_dim,
+                    output_dim=context_dim,
+                    ff_mult=4,
+                )
+            elif self.ip_mode == "global":
+                self.image_embed = ImageProjModel(
+                    cross_attention_dim=context_dim, 
+                    clip_extra_context_tokens=ip_dim)         
+            else:
+                raise ValueError(f"{self.ip_mode} is not supported")
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = (
+                            ch // num_heads
+                            if use_spatial_transformer
+                            else num_head_channels
+                        )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if (
+                        not exists(num_attention_blocks)
+                        or nr < num_attention_blocks[level]
+                    ):
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            )
+                            if not use_spatial_transformer
+                            else SpatialTransformer3D(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                disable_self_attn=disabled_sa,
+                                use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint,
+                                with_ip=self.with_ip,
+                                ip_dim=self.ip_dim, 
+                                ip_weight=self.ip_weight
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+                
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            )
+            if not use_spatial_transformer
+            else SpatialTransformer3D(  # always uses a self-attn
+                ch,
+                num_heads,
+                dim_head,
+                depth=transformer_depth,
+                context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn,
+                use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint,
+                with_ip=self.with_ip,
+                ip_dim=self.ip_dim, 
+                ip_weight=self.ip_weight
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(self.num_res_blocks[level] + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = (
+                            ch // num_heads
+                            if use_spatial_transformer
+                            else num_head_channels
+                        )
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if (
+                        not exists(num_attention_blocks)
+                        or i < num_attention_blocks[level]
+                    ):
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            )
+                            if not use_spatial_transformer
+                            else SpatialTransformer3D(
+                                ch,
+                                num_heads,
+                                dim_head,
+                                depth=transformer_depth,
+                                context_dim=context_dim,
+                                disable_self_attn=disabled_sa,
+                                use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint,
+                                with_ip=self.with_ip,
+                                ip_dim=self.ip_dim, 
+                                ip_weight=self.ip_weight
+                            )
+                        )
+                if level and i == self.num_res_blocks[level]:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+                normalization(ch),
+                conv_nd(dims, model_channels, n_embed, 1),
+                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+            )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(
+        self,
+        x,
+        timesteps=None,
+        context=None,
+        y=None,
+        camera=None,
+        num_frames=1,
+        **kwargs,
+    ):
+        """
+        Apply the model to an input batch.
+        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via cross-attention.
+        :param y: an [N] Tensor of labels, if class-conditional, default None.
+        :param num_frames: an integer indicating the number of frames (views) for tensor reshaping.
+        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
+        """
+        assert (
+            x.shape[0] % num_frames == 0
+        ), "[UNet] input batch size must be dividable by num_frames!"
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
+        emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+            
+        # Add camera embeddings
+        if camera is not None:
+            assert camera.shape[0] == emb.shape[0]
+            # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
+            emb = emb + self.camera_embed(camera)
+        ip = kwargs.get("ip", None)
+        ip_img = kwargs.get("ip_img", None)
+
+        if ip_img is not None:
+            x[(num_frames-1)::num_frames, :, :, :] = ip_img
+            
+        if ip is not None:
+            ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
+            context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context, num_frames=num_frames)
+            hs.append(h)
+        h = self.middle_block(h, emb, context, num_frames=num_frames)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context, num_frames=num_frames)
+        h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
+        if self.predict_codebook_ids: # False
+            return self.id_predictor(h)
+        else:
+            return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
+        
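+# A minimal usage sketch (hypothetical hyper-parameters, not a config shipped with this
+# repo): 2 objects x 4 views, plain attention blocks, no image prompt.
+#   model = MultiViewUNetModel(image_size=32, in_channels=4, model_channels=320,
+#                              out_channels=4, num_res_blocks=2,
+#                              attention_resolutions=[4, 2], num_heads=8)
+#   x = torch.randn(2 * 4, 4, 32, 32)          # (N*F, C, H, W)
+#   t = torch.randint(0, 1000, (2 * 4,))
+#   eps = model(x, timesteps=t, num_frames=4)  # -> (8, 4, 32, 32)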
+
+
+
+class MultiViewUNetModelStage2(MultiViewUNetModel):
+    """
+    The full multi-view UNet model with attention, timestep embedding and camera embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    :param camera_dim: dimensionality of camera input.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        use_bf16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,  # custom transformer support
+        transformer_depth=1,  # custom transformer support
+        context_dim=None,  # custom transformer support
+        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+        adm_in_channels=None,
+        camera_dim=None,
+        with_ip=False,  # whether to add image-prompt tokens
+        ip_dim=0,  # number of extra tokens: 4 for global, 16 for local
+        ip_weight=1.0,  # weight for image prompt context
+        ip_mode="local_resample",  # adaptor mode: "global" or "local_resample"
+    ):
+        super().__init__(
+            image_size,
+            in_channels,
+            model_channels,
+            out_channels,
+            num_res_blocks,
+            attention_resolutions,
+            dropout,
+            channel_mult,
+            conv_resample,
+            dims,
+            num_classes,
+            use_checkpoint,
+            use_fp16,
+            use_bf16,
+            num_heads,
+            num_head_channels,
+            num_heads_upsample,
+            use_scale_shift_norm,
+            resblock_updown,
+            use_new_attention_order,
+            use_spatial_transformer,
+            transformer_depth,
+            context_dim,
+            n_embed,
+            legacy,
+            disable_self_attentions,
+            num_attention_blocks,
+            disable_middle_self_attn,
+            use_linear_in_transformer,
+            adm_in_channels,
+            camera_dim,
+            with_ip,
+            ip_dim,
+            ip_weight,
+            ip_mode,
+        )
+
+    def forward(
+        self,
+        x,
+        timesteps=None,
+        context=None,
+        y=None,
+        camera=None,
+        num_frames=1,
+        **kwargs,
+    ):
+        """
+        Apply the model to an input batch.
+        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via cross-attention.
+        :param y: an [N] Tensor of labels, if class-conditional, default None.
+        :param num_frames: an integer indicating the number of frames (views) for tensor reshaping.
+        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
+        """
+        assert (
+            x.shape[0] % num_frames == 0
+        ), "[UNet] input batch size must be dividable by num_frames!"
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00
+        emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+            
+        # Add camera embeddings
+        if camera is not None:
+            assert camera.shape[0] == emb.shape[0]
+            # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04
+            emb = emb + self.camera_embed(camera)
+        ip = kwargs.get("ip", None)
+        ip_img = kwargs.get("ip_img", None)
+        pixel_images = kwargs.get("pixel_images", None)
+
+        if ip_img is not None:
+            x[(num_frames-1)::num_frames, :, :, :] = ip_img
+            
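+        # Stage-2 additionally conditions on per-view pixel images by concatenating them
+        # along the channel axis, so `in_channels` must account for these extra channels.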
+        x = torch.cat((x, pixel_images), dim=1)
+            
+        if ip is not None:
+            ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
+            context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context, num_frames=num_frames)
+            hs.append(h)
+        h = self.middle_block(h, emb, context, num_frames=num_frames)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context, num_frames=num_frames)
+        h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58
+        if self.predict_codebook_ids: # False
+            return self.id_predictor(h)
+        else:
+            return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93
+        
\ No newline at end of file
diff --git a/imagedream/ldm/modules/diffusionmodules/util.py b/imagedream/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..665af3eeb177bb424179e7b9b06008c510533c5d
--- /dev/null
+++ b/imagedream/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,353 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+import importlib
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == "__is_first_stage__":
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def make_beta_schedule(
+    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+    if schedule == "linear":
+        betas = (
+            torch.linspace(
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+            )
+            ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(
+            linear_start, linear_end, n_timestep, dtype=torch.float64
+        )
+    elif schedule == "sqrt":
+        betas = (
+            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+            ** 0.5
+        )
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
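+
+# A minimal usage sketch: the Stable-Diffusion-style scaled-linear schedule.
+#   betas = make_beta_schedule("linear", n_timestep=1000)  # np.ndarray, shape (1000,)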
+
+def enforce_zero_terminal_snr(betas):
+    betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1 - betas
+    alphas_bar = alphas.cumprod(0)
+    alphas_bar_sqrt = alphas_bar.sqrt()
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+    # Shift so last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+    # Scale so first timestep is back to old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt ** 2
+    alphas = alphas_bar[1:] / alphas_bar[:-1]
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+    return betas
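+
+# A minimal usage sketch: rescale a schedule so the terminal alpha_bar is exactly zero
+# (zero terminal SNR).
+#   betas_zsnr = enforce_zero_terminal_snr(make_beta_schedule("linear", 1000))  # torch.Tensor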
+
+
+def make_ddim_timesteps(
+    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+):
+    if ddim_discr_method == "uniform":
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == "quad":
+        ddim_timesteps = (
+            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+        ).astype(int)
+    else:
+        raise NotImplementedError(
+            f'There is no ddim discretization method called "{ddim_discr_method}"'
+        )
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one to get the final alpha values right (the ones from first scale to data during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f"Selected timesteps for ddim sampler: {steps_out}")
+    return steps_out
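+
+# A minimal usage sketch: 50 uniformly spaced DDIM steps out of 1000 DDPM steps.
+#   steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)  # [1, 21, 41, ..., 981]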
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt(
+        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+    )
+    if verbose:
+        print(
+            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+        )
+        print(
+            f"For the chosen value of eta, which is {eta}, "
+            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+        )
+    return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
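+
+# A minimal usage sketch: the improved-diffusion cosine schedule expressed via alpha_bar.
+#   betas = betas_for_alpha_bar(
+#       1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+#   )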
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
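+
+# A minimal usage sketch (hypothetical names): broadcast per-sample schedule coefficients
+# over an image batch.
+#   coefs = extract_into_tensor(alphas_cumprod, t, x.shape)  # -> (B, 1, 1, 1) for 4-D x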
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad():
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
+
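+# Editor's note: illustrative usage (not part of the original code). `block` is a
+# hypothetical nn.Module whose forward takes (x, emb); checkpointing is only applied
+# when `use_checkpoint` is True:
+#
+#   h = checkpoint(block.forward, (x, emb), block.parameters(), use_checkpoint)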
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+    else:
+        embedding = repeat(timesteps, "b -> b d", d=dim)
+    return embedding
+
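+# Editor's note: illustrative sketch (not part of the original code):
+#
+#   t = torch.randint(0, 1000, (batch_size,))
+#   emb = timestep_embedding(t, dim=320)  # -> float tensor of shape [batch_size, 320]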
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+    def forward(self, c_concat, c_crossattn):
+        c_concat = self.concat_conditioner(c_concat)
+        c_crossattn = self.crossattn_conditioner(c_crossattn)
+        return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+        shape[0], *((1,) * (len(shape) - 1))
+    )
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+# dummy replace
+def convert_module_to_f16(l):
+    """
+    Convert primitive modules to float16.
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.half()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+    """
+    Convert primitive modules to float32, undoing convert_module_to_f16().
+    """
+    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        l.weight.data = l.weight.data.float()
+        if l.bias is not None:
+            l.bias.data = l.bias.data.float()
diff --git a/imagedream/ldm/modules/distributions/__init__.py b/imagedream/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/modules/distributions/distributions.py b/imagedream/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b61f03077358ce4737c85842d9871f70dabb656
--- /dev/null
+++ b/imagedream/ldm/modules/distributions/distributions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(
+                device=self.parameters.device
+            )
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(
+            device=self.parameters.device
+        )
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
+
+    def nll(self, sample, dims=[1, 2, 3]):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self):
+        return self.mean
+
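+# Editor's note: illustrative sketch (not part of the original code). `moments` is a
+# hypothetical encoder output stacking mean and log-variance along dim=1:
+#
+#   moments = encoder(x)                             # [B, 2*C, H, W]
+#   posterior = DiagonalGaussianDistribution(moments)
+#   z = posterior.sample()                           # [B, C, H, W]
+#   kl_loss = posterior.kl().mean()                  # scalar KL regularizer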
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
diff --git a/imagedream/ldm/modules/ema.py b/imagedream/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68
--- /dev/null
+++ b/imagedream/ldm/modules/ema.py
@@ -0,0 +1,86 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_upates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.m_name2s_name = {}
+        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer(
+            "num_updates",
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
+        )
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # strip '.' since it is not allowed in buffer names
+                s_name = name.replace(".", "")
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
+                else:
+                    assert key not in self.m_name2s_name
+
+    def copy_to(self, model):
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert key not in self.m_name2s_name
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
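+
+
+# Editor's note: typical usage sketch (not part of the original code):
+#
+#   ema = LitEma(model)
+#   ...
+#   ema(model)                         # after every optimizer step
+#   ema.store(model.parameters())      # stash the training weights
+#   ema.copy_to(model)                 # evaluate / save with EMA weights
+#   ema.restore(model.parameters())    # switch back to the training weights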
diff --git a/imagedream/ldm/modules/encoders/__init__.py b/imagedream/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/imagedream/ldm/modules/encoders/modules.py b/imagedream/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..8efe14f3aeb29cd93203ea0f32a36f624c81aafd
--- /dev/null
+++ b/imagedream/ldm/modules/encoders/modules.py
@@ -0,0 +1,329 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import numpy as np
+import open_clip
+from PIL import Image
+from ...util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+    def encode(self, x):
+        return x
+
+
+class ClassEmbedder(nn.Module):
+    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate
+
+    def forward(self, batch, key=None, disable_dropout=False):
+        if key is None:
+            key = self.key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        if self.ucg_rate > 0.0 and not disable_dropout:
+            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
+            c = c.long()
+        c = self.embedding(c)
+        return c
+
+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = (
+            self.n_classes - 1
+        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+        uc = torch.ones((bs,), device=device) * uc_class
+        uc = {self.key: uc}
+        return uc
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+
+    def __init__(
+        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
+    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length  # TODO: typical value?
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        # self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding="max_length",
+            return_tensors="pt",
+        )
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+    LAYERS = ["last", "pooled", "hidden"]
+
+    def __init__(
+        self,
+        version="openai/clip-vit-large-patch14",
+        device="cuda",
+        max_length=77,
+        freeze=True,
+        layer="last",
+        layer_idx=None,
+    ):  # clip-vit-base-patch32
+        super().__init__()
+        assert layer in self.LAYERS
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        self.layer_idx = layer_idx
+        if layer == "hidden":
+            assert layer_idx is not None
+            assert 0 <= abs(layer_idx) <= 12
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        # self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(
+            text,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding="max_length",
+            return_tensors="pt",
+        )
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(
+            input_ids=tokens, output_hidden_states=self.layer == "hidden"
+        )
+        if self.layer == "last":
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
+        else:
+            z = outputs.hidden_states[self.layer_idx]
+        return z
+
+    def encode(self, text):
+        return self(text)
+
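+# Editor's note: illustrative usage of FrozenCLIPEmbedder (not part of the original
+# code); the Hugging Face checkpoint is downloaded on first use:
+#
+#   enc = FrozenCLIPEmbedder(device="cpu")
+#   z = enc.encode(["a photo of a cat"])   # [1, 77, hidden_size] token embeddings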
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+
+    LAYERS = [
+        # "pooled",
+        "last",
+        "penultimate",
+    ]
+
+    def __init__(
+        self,
+        arch="ViT-H-14",
+        version="laion2b_s32b_b79k",
+        device="cuda",
+        max_length=77,
+        freeze=True,
+        layer="last",
+        ip_mode=None
+    ):
+        """_summary_
+
+        Args:
+            ip_mode (str, optional): what is the image promcessing mode. Defaults to None.
+
+        """
+        super().__init__()
+        assert layer in self.LAYERS
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            arch, device=torch.device("cpu"), pretrained=version
+        )
+        if ip_mode is None:
+            del model.visual
+            
+        self.model = model
+        self.preprocess = preprocess
+        self.device = device
+        self.max_length = max_length
+        self.ip_mode = ip_mode
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "last":
+            self.layer_idx = 0
+        elif self.layer == "penultimate":
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.encode_with_transformer(tokens.to(self.device))
+        return z
+    
+    def forward_image(self, pil_image):
+        if isinstance(pil_image, Image.Image):
+            pil_image = [pil_image]
+        if isinstance(pil_image, torch.Tensor):
+            pil_image = pil_image.cpu().numpy()
+        if isinstance(pil_image, np.ndarray):
+            if pil_image.ndim == 3:
+                pil_image = pil_image[None, :, :, :]
+            pil_image = [Image.fromarray(x) for x in pil_image]
+
+        images = []
+        for image in pil_image:
+            images.append(self.preprocess(image).to(self.device))
+
+        image = torch.stack(images, 0) # to [b, 3, h, w]
+        if self.ip_mode == "global":
+            image_features = self.model.encode_image(image)
+            image_features /= image_features.norm(dim=-1, keepdim=True)
+        elif "local" in self.ip_mode:
+            image_features = self.encode_image_with_transformer(image)
+
+        return image_features # b, l
+    
+    def encode_image_with_transformer(self, x):
+        visual = self.model.visual
+        x = visual.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        # class embeddings and positional embeddings
+        x = torch.cat(
+            [visual.class_embedding.to(x.dtype) + \
+             torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + visual.positional_embedding.to(x.dtype)
+
+        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+        # x = visual.patch_dropout(x) 
+        x = visual.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        hidden = self.image_transformer_forward(x)
+        x = hidden[-2].permute(1, 0, 2)  # LND -> NLD
+        return x
+    
+    def image_transformer_forward(self, x):
+        encoder_states = ()
+        trans = self.model.visual.transformer
+        for r in trans.resblocks:
+            if trans.grad_checkpointing and not torch.jit.is_scripting():
+                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
+                x = checkpoint(r, x, None, None, None)
+            else:
+                x = r(x, attn_mask=None)
+            encoder_states = encoder_states + (x, )
+        return encoder_states
+    
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if (
+                self.model.transformer.grad_checkpointing
+                and not torch.jit.is_scripting()
+            ):
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
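+# Editor's note: illustrative usage (not part of the original code); "ref.png" is a
+# hypothetical input. With ip_mode set, the visual tower is kept so forward_image
+# can return per-patch tokens (e.g. [1, 257, 1280] for ViT-H/14 in "local" mode):
+#
+#   enc = FrozenOpenCLIPEmbedder(device="cpu", ip_mode="local")
+#   tokens = enc.forward_image(Image.open("ref.png").convert("RGB"))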
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+    def __init__(
+        self,
+        clip_version="openai/clip-vit-large-patch14",
+        t5_version="google/t5-v1_1-xl",
+        device="cuda",
+        clip_max_length=77,
+        t5_max_length=77,
+    ):
+        super().__init__()
+        self.clip_encoder = FrozenCLIPEmbedder(
+            clip_version, device, max_length=clip_max_length
+        )
+        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+        print(
+            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
+        )
+
+    def encode(self, text):
+        return self(text)
+
+    def forward(self, text):
+        clip_z = self.clip_encoder.encode(text)
+        t5_z = self.t5_encoder.encode(text)
+        return [clip_z, t5_z]
diff --git a/imagedream/ldm/util.py b/imagedream/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f541504a1181f6164b1084a6ffb0de365a9b06ba
--- /dev/null
+++ b/imagedream/ldm/util.py
@@ -0,0 +1,226 @@
+import importlib
+
+import random
+import torch
+import numpy as np
+from collections import abc
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh a tuple of (width, height)
+    # xc a list of captions to plot
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+        nc = int(40 * (wh[0] / 256))
+        lines = "\n".join(
+            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
+        )
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == "__is_first_stage__":
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
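+# Editor's note: illustrative usage (not part of the original code):
+#
+#   cfg = {"target": "torch.nn.Conv2d",
+#          "params": {"in_channels": 3, "out_channels": 8, "kernel_size": 3}}
+#   conv = instantiate_from_config(cfg)   # equivalent to torch.nn.Conv2d(3, 8, 3)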
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+    # create dummy dataset instance
+
+    # run prefetching
+    if idx_to_fn:
+        res = func(data, worker_id=idx)
+    else:
+        res = func(data)
+    Q.put([idx, res])
+    Q.put("Done")
+
+
+def parallel_data_prefetch(
+    func: callable,
+    data,
+    n_proc,
+    target_data_type="ndarray",
+    cpu_intensive=True,
+    use_worker_id=False,
+):
+    # if target_data_type not in ["ndarray", "list"]:
+    #     raise ValueError(
+    #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+    #     )
+    if isinstance(data, np.ndarray) and target_data_type == "list":
+        raise ValueError("list expected but function got ndarray.")
+    elif isinstance(data, abc.Iterable):
+        if isinstance(data, dict):
+            print(
+                'WARNING: "data" argument passed to parallel_data_prefetch is a dict: using only its values and disregarding keys.'
+            )
+            data = list(data.values())
+        if target_data_type == "ndarray":
+            data = np.asarray(data)
+        else:
+            data = list(data)
+    else:
+        raise TypeError(
+            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+        )
+
+    if cpu_intensive:
+        Q = mp.Queue(1000)
+        proc = mp.Process
+    else:
+        Q = Queue(1000)
+        proc = Thread
+    # spawn processes
+    if target_data_type == "ndarray":
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(np.array_split(data, n_proc))
+        ]
+    else:
+        step = (
+            int(len(data) / n_proc + 1)
+            if len(data) % n_proc != 0
+            else int(len(data) / n_proc)
+        )
+        arguments = [
+            [func, Q, part, i, use_worker_id]
+            for i, part in enumerate(
+                [data[i : i + step] for i in range(0, len(data), step)]
+            )
+        ]
+    processes = []
+    for i in range(n_proc):
+        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+        processes += [p]
+
+    # start processes
+    print(f"Start prefetching...")
+    import time
+
+    start = time.time()
+    gather_res = [[] for _ in range(n_proc)]
+    try:
+        for p in processes:
+            p.start()
+
+        k = 0
+        while k < n_proc:
+            # get result
+            res = Q.get()
+            if res == "Done":
+                k += 1
+            else:
+                gather_res[res[0]] = res[1]
+
+    except Exception as e:
+        print("Exception: ", e)
+        for p in processes:
+            p.terminate()
+
+        raise e
+    finally:
+        for p in processes:
+            p.join()
+        print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+    if target_data_type == "ndarray":
+        if not isinstance(gather_res[0], np.ndarray):
+            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+        # order outputs
+        return np.concatenate(gather_res, axis=0)
+    elif target_data_type == "list":
+        out = []
+        for r in gather_res:
+            out.extend(r)
+        return out
+    else:
+        return gather_res
+
+def set_seed(seed=None):
+    random.seed(seed)
+    np.random.seed(seed)
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+def add_random_background(image, bg_color=None):
+    bg_color = np.random.rand() * 255 if bg_color is None else bg_color
+    image = np.array(image)
+    rgb, alpha = image[..., :3], image[..., 3:]
+    alpha = alpha.astype(np.float32) / 255.0
+    image_new = rgb * alpha + bg_color * (1 - alpha)
+    return Image.fromarray(image_new.astype(np.uint8))
\ No newline at end of file
diff --git a/imagedream/model_zoo.py b/imagedream/model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d65529ab1b717e3d60a0d38e2fb9fea1abf0c006
--- /dev/null
+++ b/imagedream/model_zoo.py
@@ -0,0 +1,64 @@
+""" Utiliy functions to load pre-trained models more easily """
+import os
+import pkg_resources
+from omegaconf import OmegaConf
+
+import torch
+from huggingface_hub import hf_hub_download
+
+from imagedream.ldm.util import instantiate_from_config
+
+
+PRETRAINED_MODELS = {
+    "sd-v2.1-base-4view-ipmv": {
+        "config": "sd_v2_base_ipmv.yaml",
+        "repo_id": "Peng-Wang/ImageDream",
+        "filename": "sd-v2.1-base-4view-ipmv.pt",
+    },
+    "sd-v2.1-base-4view-ipmv-local": {
+        "config": "sd_v2_base_ipmv_local.yaml",
+        "repo_id": "Peng-Wang/ImageDream",
+        "filename": "sd-v2.1-base-4view-ipmv-local.pt",
+    },
+}
+
+
+def get_config_file(config_path):
+    cfg_file = pkg_resources.resource_filename(
+        "imagedream", os.path.join("configs", config_path)
+    )
+    if not os.path.exists(cfg_file):
+        raise RuntimeError(f"Config {config_path} not available!")
+    return cfg_file
+
+
+def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None):
+    if (config_path is not None) and (ckpt_path is not None):
+        config = OmegaConf.load(config_path)
+        model = instantiate_from_config(config.model)
+        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
+        return model
+        
+    if model_name not in PRETRAINED_MODELS:
+        raise RuntimeError(
+            f"Model name {model_name} is not a pre-trained model. Available models are:\n- "
+            + "\n- ".join(PRETRAINED_MODELS.keys())
+        )
+    model_info = PRETRAINED_MODELS[model_name]
+
+    # Instantiate the model
+    print(f"Loading model from config: {model_info['config']}")
+    config_file = get_config_file(model_info["config"])
+    config = OmegaConf.load(config_file)
+    model = instantiate_from_config(config.model)
+
+    # Load pre-trained checkpoint from huggingface
+    if not ckpt_path:
+        ckpt_path = hf_hub_download(
+            repo_id=model_info["repo_id"],
+            filename=model_info["filename"],
+            cache_dir=cache_dir,
+        )
+        print(f"Loading model from cache file: {ckpt_path}")
+    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
+    return model
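+
+
+# Editor's note: illustrative usage (not part of the original code); the checkpoint is
+# fetched from the Hugging Face Hub on the first call:
+#
+#   model = build_model("sd-v2.1-base-4view-ipmv")
+#   model = model.to("cuda").eval()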
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8a1be926e2fb3381acf1dbf6772b0d393451d70
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,91 @@
+import numpy as np
+import torch
+import time
+import nvdiffrast.torch as dr
+from util.utils import get_tri
+import tempfile
+from mesh import Mesh
+import zipfile
+def generate3d(model, rgb, ccm, device):
+
+    color_tri = torch.from_numpy(rgb)/255
+    xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255
+    color = color_tri.permute(2,0,1)
+    xyz = xyz_tri.permute(2,0,1)
+
+
+    def get_imgs(color):
+        # color : [C, H, W*6]
+        color_list = []
+        color_list.append(color[:,:,256*5:256*(1+5)])
+        for i in range(0,5):
+            color_list.append(color[:,:,256*i:256*(1+i)])
+        return torch.stack(color_list, dim=0)# [6, C, H, W]
+    
+    triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C]
+
+    color = get_imgs(color)
+    xyz = get_imgs(xyz)
+
+    color = get_tri(color, dim=0, blender=True, scale=1).unsqueeze(0)
+    xyz = get_tri(xyz, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
+
+    triplane = torch.cat([color,xyz],dim=1).to(device)
+    # 3D visualize
+    model.eval()
+    glctx = dr.RasterizeCudaContext()
+
+    if model.denoising:
+        tnew = 20
+        tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device)
+        noise_new = torch.randn_like(triplane) *0.5+0.5
+        triplane = model.scheduler.add_noise(triplane, noise_new, tnew)    
+        start_time = time.time()
+        with torch.no_grad():
+            triplane_feature2 = model.unet2(triplane,tnew)
+        end_time = time.time()
+        elapsed_time = end_time - start_time
+        print(f"unet takes {elapsed_time}s")
+    else:
+        triplane_feature2 = model.unet2(triplane)
+        
+
+    with torch.no_grad():
+        data_config = {
+            'resolution': [1024, 1024],
+            "triview_color": triplane_color.to(device),
+        }
+
+        verts, faces = model.decode(data_config, triplane_feature2)
+
+        data_config['verts'] = verts[0]
+        data_config['faces'] = faces
+        
+
+    from kiui.mesh_utils import clean_mesh
+    verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005)
+    data_config['verts'] = torch.from_numpy(verts).cuda().contiguous()
+    data_config['faces'] = torch.from_numpy(faces).cuda().contiguous()
+
+    start_time = time.time()
+    with torch.no_grad():
+        mesh_path_obj = tempfile.NamedTemporaryFile(suffix="", delete=False).name
+        model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2)
+
+        mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, front_dir="+z")
+        mesh_path_glb = tempfile.NamedTemporaryFile(suffix="", delete=False).name
+        mesh.write(mesh_path_glb+".glb")
+
+        # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb')
+        # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name
+        # mesh_obj2.export(mesh_path_obj2+".obj")
+
+        with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip:
+            myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj')
+            myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png')
+            myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl')
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"uv takes {elapsed_time}s")
+    return mesh_path_glb+".glb", mesh_path_obj+'.zip'
diff --git a/libs/base_utils.py b/libs/base_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9233461bd699b9df025a7a532ea72a7c023b30f0
--- /dev/null
+++ b/libs/base_utils.py
@@ -0,0 +1,84 @@
+import numpy as np
+import cv2
+import torch
+from PIL import Image
+
+        
+def instantiate_from_config(config):
+    if not "target" in config:
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    import importlib
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+        
+def tensor_detail(t):
+    assert type(t) == torch.Tensor
+    print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}")
+
+
+
+def drawRoundRec(draw, color, x, y, w, h, r):
+    drawObject = draw
+
+    # rounded corners
+    drawObject.ellipse((x, y, x + r, y + r), fill=color)
+    drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
+    drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
+    drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
+
+    # connecting rectangles
+    drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
+    drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
+
+
+def do_resize_content(original_image: Image.Image, scale_rate):
+    # resize the image content while retaining the original image size
+    if scale_rate != 1:
+        # Calculate the new size after rescaling
+        new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
+        # Resize the image while maintaining the aspect ratio
+        resized_image = original_image.resize(new_size)
+        # Create a new transparent RGBA image with the original size
+        padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
+        paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
+        padded_image.paste(resized_image, paste_position)
+        return padded_image
+    else:
+        return original_image
+
+def add_stroke(img, color=(255, 255, 255), stroke_radius=3):
+    # color in R, G, B format
+    if isinstance(img, Image.Image):
+        assert img.mode == "RGBA"
+        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA)
+    else:
+        assert img.shape[2] == 4
+    gray = img[:,:, 3]
+    ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY)
+    contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)  
+    res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius)
+    return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA))
+
+def make_blob(image_size=(512, 512), sigma=0.2):
+    """
+    make 2D blob image with:
+    I(x, y)=1-\exp \left(-\frac{(x-H / 2)^2+(y-W / 2)^2}{2 \sigma^2 HS}\right)
+    """
+    import numpy as np
+    H, W = image_size
+    x = np.arange(0, W, 1, float)
+    y = np.arange(0, H, 1, float)
+    x, y = np.meshgrid(x, y)
+    x0 = W // 2
+    y0 = H // 2
+    img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W))
+    return (img * 255).astype(np.uint8)
\ No newline at end of file
diff --git a/libs/sample.py b/libs/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ed9a20bd52a626dcc50199c1636a0c87ae85f5
--- /dev/null
+++ b/libs/sample.py
@@ -0,0 +1,380 @@
+import numpy as np
+import torch
+from imagedream.camera_utils import get_camera_for_index
+from imagedream.ldm.util import set_seed, add_random_background
+from libs.base_utils import do_resize_content
+from imagedream.ldm.models.diffusion.ddim import DDIMSampler
+from torchvision import transforms as T
+
+
+class ImageDreamDiffusion:
+    def __init__(
+        self,
+        model,
+        device,
+        dtype,
+        mode,
+        num_frames,
+        camera_views,
+        ref_position,
+        random_background=False,
+        offset_noise=False,
+        resize_rate=1,
+        image_size=256,
+        seed=1234,
+    ) -> None:
+        assert mode in ["pixel", "local"]
+        size = image_size
+        self.seed = seed
+        batch_size = max(4, num_frames)
+
+        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated,  obscure, unnatural colors, poor lighting, dull, and unclear."
+        uc = model.get_learned_conditioning([neg_texts]).to(device)
+        sampler = DDIMSampler(model)
+
+        # pre-compute camera matrices
+        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
+        camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
+        camera = torch.stack(camera)
+        camera = camera.repeat(batch_size // num_frames, 1).to(device)
+
+        self.image_transform = T.Compose(
+            [
+                T.Resize((size, size)),
+                T.ToTensor(),
+                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        )
+        self.dtype = dtype
+        self.ref_position = ref_position
+        self.mode = mode
+        self.random_background = random_background
+        self.resize_rate = resize_rate
+        self.num_frames = num_frames
+        self.size = size
+        self.device = device
+        self.batch_size = batch_size
+        self.model = model
+        self.sampler = sampler
+        self.uc = uc
+        self.camera = camera
+        self.offset_noise = offset_noise
+
+    @staticmethod
+    def i2i(
+        model,
+        image_size,
+        prompt,
+        uc,
+        sampler,
+        ip=None,
+        step=20,
+        scale=5.0,
+        batch_size=8,
+        ddim_eta=0.0,
+        dtype=torch.float32,
+        device="cuda",
+        camera=None,
+        num_frames=4,
+        pixel_control=False,
+        transform=None,
+        offset_noise=False,
+    ):
+        """ The function supports additional image prompt.
+        Args:
+            model (_type_): the image dream model
+            image_size (_type_): size of diffusion output (standard 256)
+            prompt (_type_): text prompt for the image (prompt in type str)
+            uc (_type_): unconditional vector (tensor in shape [1, 77, 1024])
+            sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler
+            ip (Image, optional): the image prompt. Defaults to None.
+            step (int, optional): _description_. Defaults to 20.
+            scale (float, optional): _description_. Defaults to 7.5.
+            batch_size (int, optional): _description_. Defaults to 8.
+            ddim_eta (float, optional): _description_. Defaults to 0.0.
+            dtype (_type_, optional): _description_. Defaults to torch.float32.
+            device (str, optional): _description_. Defaults to "cuda".
+            camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
+            num_frames (int, optional): _num of frames (views) to generate
+            pixel_control: whether to use pixel conditioning. Defaults to False, True when using pixel mode
+            transform: Compose(
+                Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn)
+                ToTensor()
+                Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+            )
+        """
+        ip_raw = ip
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
+            c = model.get_learned_conditioning(prompt).to(
+                device
+            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
+            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
+            uc_ = {"context": uc.repeat(batch_size, 1, 1)}
+
+            if camera is not None:
+                c_["camera"] = uc_["camera"] = (
+                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
+                )
+                c_["num_frames"] = uc_["num_frames"] = num_frames
+
+            if ip is not None:
+                ip_embed = model.get_learned_image_conditioning(ip).to(
+                    device
+                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
+                ip_ = ip_embed.repeat(batch_size, 1, 1)
+                c_["ip"] = ip_
+                uc_["ip"] = torch.zeros_like(ip_)
+
+            if pixel_control:
+                assert camera is not None
+                ip = transform(ip).to(
+                    device
+                )  # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00
+                ip_img = model.get_first_stage_encoding(
+                    model.encode_first_stage(ip[None, :, :, :])
+                )  # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55
+                c_["ip_img"] = ip_img
+                uc_["ip_img"] = torch.zeros_like(ip_img)
+
+            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
+            if offset_noise:
+                ref = transform(ip_raw).to(device)
+                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
+                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
+                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
+                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
+
+            samples_ddim, _ = (
+                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
+                    S=step,
+                    conditioning=c_,
+                    batch_size=batch_size,
+                    shape=shape,
+                    verbose=False,
+                    unconditional_guidance_scale=scale,
+                    unconditional_conditioning=uc_,
+                    eta=ddim_eta,
+                    x_T=x_T if offset_noise else None,
+                )
+            )
+
+            x_sample = model.decode_first_stage(samples_ddim)
+            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
+
+        return list(x_sample.astype(np.uint8))
+
+    def diffuse(self, t, ip, n_test=2):
+        set_seed(self.seed)
+        ip = do_resize_content(ip, self.resize_rate)
+        if self.random_background:
+            ip = add_random_background(ip)
+
+        images = []
+        for _ in range(n_test):
+            img = self.i2i(
+                self.model,
+                self.size,
+                t,
+                self.uc,
+                self.sampler,
+                ip=ip,
+                step=50,
+                scale=5,
+                batch_size=self.batch_size,
+                ddim_eta=0.0,
+                dtype=self.dtype,
+                device=self.device,
+                camera=self.camera,
+                num_frames=self.num_frames,
+                pixel_control=(self.mode == "pixel"),
+                transform=self.image_transform,
+                offset_noise=self.offset_noise,
+            )
+            img = np.concatenate(img, 1)
+            img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1)
+            images.append(img)
+        set_seed()  # unset random and numpy seed
+        return images
+
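+# Editor's note: illustrative usage sketch (not part of the original code). The camera
+# view indices and the input image path are hypothetical; `model` comes from
+# imagedream.model_zoo.build_model:
+#
+#   pipe = ImageDreamDiffusion(model, device="cuda", dtype=torch.float16, mode="pixel",
+#                              num_frames=5, camera_views=[0, 1, 2, 3, 4], ref_position=4)
+#   grids = pipe.diffuse("a 3D asset", PIL.Image.open("input.png").convert("RGB"), n_test=1)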
+
+class ImageDreamDiffusionStage2:
+    def __init__(
+        self,
+        model,
+        device,
+        dtype,
+        num_frames,
+        camera_views,
+        ref_position,
+        random_background=False,
+        offset_noise=False,
+        resize_rate=1,
+        mode="pixel",
+        image_size=256,
+        seed=1234,
+    ) -> None:
+        assert mode in ["pixel", "local"]
+
+        size = image_size
+        self.seed = seed
+        batch_size = max(4, num_frames)
+
+        neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated,  obscure, unnatural colors, poor lighting, dull, and unclear."
+        uc = model.get_learned_conditioning([neg_texts]).to(device)
+        sampler = DDIMSampler(model)
+
+        # pre-compute camera matrices
+        camera = [get_camera_for_index(i).squeeze() for i in camera_views]
+        if ref_position is not None:
+            camera[ref_position] = torch.zeros_like(camera[ref_position])  # set ref camera to zero
+        camera = torch.stack(camera)
+        camera = camera.repeat(batch_size // num_frames, 1).to(device)
+
+        self.image_transform = T.Compose(
+            [
+                T.Resize((size, size)),
+                T.ToTensor(),
+                T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            ]
+        )
+
+        self.dtype = dtype
+        self.mode = mode
+        self.ref_position = ref_position
+        self.random_background = random_background
+        self.resize_rate = resize_rate
+        self.num_frames = num_frames
+        self.size = size
+        self.device = device
+        self.batch_size = batch_size
+        self.model = model
+        self.sampler = sampler
+        self.uc = uc
+        self.camera = camera
+        self.offset_noise = offset_noise
+
+    @staticmethod
+    def i2iStage2(
+        model,
+        image_size,
+        prompt,
+        uc,
+        sampler,
+        pixel_images,
+        ip=None,
+        step=20,
+        scale=5.0,
+        batch_size=8,
+        ddim_eta=0.0,
+        dtype=torch.float32,
+        device="cuda",
+        camera=None,
+        num_frames=4,
+        pixel_control=False,
+        transform=None,
+        offset_noise=False,
+    ):
+        ip_raw = ip
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+        with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype):
+            c = model.get_learned_conditioning(prompt).to(
+                device
+            )  # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05
+            c_ = {"context": c.repeat(batch_size, 1, 1)}  # batch_size
+            uc_ = {"context": uc.repeat(batch_size, 1, 1)}
+
+            if camera is not None:
+                c_["camera"] = uc_["camera"] = (
+                    camera  # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00
+                )
+                c_["num_frames"] = uc_["num_frames"] = num_frames
+
+            if ip is not None:
+                ip_embed = model.get_learned_image_conditioning(ip).to(
+                    device
+                )  # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12
+                ip_ = ip_embed.repeat(batch_size, 1, 1)
+                c_["ip"] = ip_
+                uc_["ip"] = torch.zeros_like(ip_)
+
+            if pixel_control:
+                assert camera is not None
+                
+            transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images])
+            latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images))
+
+            c_["pixel_images"] = latent_pixel_images
+            uc_["pixel_images"] = torch.zeros_like(latent_pixel_images)
+
+            shape = [4, image_size // 8, image_size // 8]  # [4, 32, 32]
+            if offset_noise:
+                ref = transform(ip_raw).to(device)
+                ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :]))
+                ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True)
+                time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device)
+                x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps)
+
+            samples_ddim, _ = (
+                sampler.sample(  # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43
+                    S=step,
+                    conditioning=c_,
+                    batch_size=batch_size,
+                    shape=shape,
+                    verbose=False,
+                    unconditional_guidance_scale=scale,
+                    unconditional_conditioning=uc_,
+                    eta=ddim_eta,
+                    x_T=x_T if offset_noise else None,
+                )
+            )
+            x_sample = model.decode_first_stage(samples_ddim)
+            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+            x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy()
+
+        return list(x_sample.astype(np.uint8))
+
+    @torch.no_grad()
+    def diffuse(self, t, ip, pixel_images, n_test=2):
+        set_seed(self.seed)
+        ip = do_resize_content(ip, self.resize_rate)
+        pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images]
+
+        if self.random_background:
+            bg_color = np.random.rand() * 255
+            ip = add_random_background(ip, bg_color)
+            pixel_images = [add_random_background(i, bg_color) for i in pixel_images]
+
+        images = []
+        for _ in range(n_test):
+            img = self.i2iStage2(
+                self.model,
+                self.size,
+                t,
+                self.uc,
+                self.sampler,
+                pixel_images=pixel_images,
+                ip=ip,
+                step=50,
+                scale=5,
+                batch_size=self.batch_size,
+                ddim_eta=0.0,
+                dtype=self.dtype,
+                device=self.device,
+                camera=self.camera,
+                num_frames=self.num_frames,
+                pixel_control=(self.mode == "pixel"),
+                transform=self.image_transform,
+                offset_noise=self.offset_noise,
+            )
+            img = np.concatenate(img, 1)
+            img = np.concatenate(
+                (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]),
+                axis=1,
+            )
+            images.append(img)
+        set_seed()  # unset random and numpy seed
+        return images
diff --git a/mesh.py b/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..225a271023da1f5fc3e6a709ce465a6fa91afb5d
--- /dev/null
+++ b/mesh.py
@@ -0,0 +1,845 @@
+import os
+import cv2
+import torch
+import trimesh
+import numpy as np
+
+from kiui.op import safe_normalize, dot
+from kiui.typing import *
+
+class Mesh:
+    """
+    A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
+
+    Note:
+        This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
+    """
+    def __init__(
+        self,
+        v: Optional[Tensor] = None,
+        f: Optional[Tensor] = None,
+        vn: Optional[Tensor] = None,
+        fn: Optional[Tensor] = None,
+        vt: Optional[Tensor] = None,
+        ft: Optional[Tensor] = None,
+        vc: Optional[Tensor] = None, # vertex color
+        albedo: Optional[Tensor] = None,
+        metallicRoughness: Optional[Tensor] = None,
+        device: Optional[torch.device] = None,
+    ):
+        """Init a mesh directly using all attributes.
+
+        Args:
+            v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
+            f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
+            vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
+            fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
+            vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
+            ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
+            vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
+            albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
+            metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
+            device (Optional[torch.device]): torch device. Defaults to None.
+        """
+        self.device = device
+        self.v = v
+        self.vn = vn
+        self.vt = vt
+        self.f = f
+        self.fn = fn
+        self.ft = ft
+        # will first see if there is vertex color to use
+        self.vc = vc
+        # only support a single albedo image
+        self.albedo = albedo
+        # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
+        # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
+        self.metallicRoughness = metallicRoughness
+
+        self.ori_center = 0
+        self.ori_scale = 1
+
+    @classmethod
+    def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
+        """load mesh from path.
+
+        Args:
+            path (str): path to mesh file, supports ply, obj, glb.
+            clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
+            resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
+            renormal (bool, optional): re-calc the vertex normals. Defaults to True.
+            retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
+            bound (float, optional): bound to resize. Defaults to 0.9.
+            front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
+            device (torch.device, optional): torch device. Defaults to None.
+        
+        Note:
+            a ``device`` keyword argument can be provided to specify the torch device. 
+            If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
+
+        Returns:
+            Mesh: the loaded Mesh object.
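+
+        Example:
+            A minimal usage sketch (the file paths below are placeholders, not files shipped with this repo)::
+
+                mesh = Mesh.load("path/to/your.glb", resize=True, front_dir="+z")
+                print(mesh.v.shape, mesh.f.shape)
+                mesh.write("path/to/output.obj")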
+        """
+        # obj supports face uv
+        if path.endswith(".obj"):
+            mesh = cls.load_obj(path, **kwargs)
+        # trimesh only supports vertex uv, but can load more formats
+        else:
+            mesh = cls.load_trimesh(path, **kwargs)
+        
+        # clean
+        if clean:
+            from kiui.mesh_utils import clean_mesh
+            vertices = mesh.v.detach().cpu().numpy()
+            triangles = mesh.f.detach().cpu().numpy()
+            vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
+            mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
+            mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
+
+        print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
+        # auto-normalize
+        if resize:
+            mesh.auto_size(bound=bound)
+        # auto-fix normal
+        if renormal or mesh.vn is None:
+            mesh.auto_normal()
+            print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
+        # auto-fix texcoords
+        if retex or (mesh.albedo is not None and mesh.vt is None):
+            mesh.auto_uv(cache_path=path)
+            print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
+
+        # rotate front dir to +z
+        if front_dir != "+z":
+            # axis switch
+            if "-z" in front_dir:
+                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
+            elif "+x" in front_dir:
+                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
+            elif "-x" in front_dir:
+                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
+            elif "+y" in front_dir:
+                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
+            elif "-y" in front_dir:
+                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
+            else:
+                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
+            # rotation (how many 90 degrees)
+            if '1' in front_dir:
+                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) 
+            elif '2' in front_dir:
+                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) 
+            elif '3' in front_dir:
+                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) 
+            mesh.v @= T
+            mesh.vn @= T
+
+        return mesh
+
+    # load from obj file
+    @classmethod
+    def load_obj(cls, path, albedo_path=None, device=None):
+        """load an ``obj`` mesh.
+
+        Args:
+            path (str): path to mesh.
+            albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None.
+            device (torch.device, optional): torch device. Defaults to None.
+        
+        Note: 
+            We will try to read the `mtl` path from the `obj` file; otherwise we assume it shares the `obj` file name but with the `mtl` extension.
+            The `usemtl` statement is ignored, and we only use the last material path found in the `mtl` file.
+
+        Returns:
+            Mesh: the loaded Mesh object.
+        """
+        assert os.path.splitext(path)[-1] == ".obj"
+
+        mesh = cls()
+
+        # device
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        mesh.device = device
+
+        # load obj
+        with open(path, "r") as f:
+            lines = f.readlines()
+
+        def parse_f_v(fv):
+            # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
+            # supported forms:
+            # f v1 v2 v3
+            # f v1/vt1 v2/vt2 v3/vt3
+            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
+            # f v1//vn1 v2//vn2 v3//vn3
+            xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
+            xs.extend([-1] * (3 - len(xs)))
+            return xs[0], xs[1], xs[2]
+
+        vertices, texcoords, normals = [], [], []
+        faces, tfaces, nfaces = [], [], []
+        mtl_path = None
+
+        for line in lines:
+            split_line = line.split()
+            # empty line
+            if len(split_line) == 0:
+                continue
+            prefix = split_line[0].lower()
+            # mtllib
+            if prefix == "mtllib":
+                mtl_path = split_line[1]
+            # usemtl
+            elif prefix == "usemtl":
+                pass # ignored
+            # v/vn/vt
+            elif prefix == "v":
+                vertices.append([float(v) for v in split_line[1:]])
+            elif prefix == "vn":
+                normals.append([float(v) for v in split_line[1:]])
+            elif prefix == "vt":
+                val = [float(v) for v in split_line[1:]]
+                texcoords.append([val[0], 1.0 - val[1]])
+            elif prefix == "f":
+                vs = split_line[1:]
+                nv = len(vs)
+                v0, t0, n0 = parse_f_v(vs[0])
+                for i in range(nv - 2):  # triangulate (assume vertices are ordered)
+                    v1, t1, n1 = parse_f_v(vs[i + 1])
+                    v2, t2, n2 = parse_f_v(vs[i + 2])
+                    faces.append([v0, v1, v2])
+                    tfaces.append([t0, t1, t2])
+                    nfaces.append([n0, n1, n2])
+
+        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
+        mesh.vt = (
+            torch.tensor(texcoords, dtype=torch.float32, device=device)
+            if len(texcoords) > 0
+            else None
+        )
+        mesh.vn = (
+            torch.tensor(normals, dtype=torch.float32, device=device)
+            if len(normals) > 0
+            else None
+        )
+
+        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
+        mesh.ft = (
+            torch.tensor(tfaces, dtype=torch.int32, device=device)
+            if len(texcoords) > 0
+            else None
+        )
+        mesh.fn = (
+            torch.tensor(nfaces, dtype=torch.int32, device=device)
+            if len(normals) > 0
+            else None
+        )
+
+        # see if there is vertex color
+        use_vertex_color = False
+        if mesh.v.shape[1] == 6:
+            use_vertex_color = True
+            mesh.vc = mesh.v[:, 3:]
+            mesh.v = mesh.v[:, :3]
+            print(f"[load_obj] use vertex color: {mesh.vc.shape}")
+
+        # try to load texture image
+        if not use_vertex_color:
+            # try to retrieve mtl file
+            mtl_path_candidates = []
+            if mtl_path is not None:
+                mtl_path_candidates.append(mtl_path)
+                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
+            mtl_path_candidates.append(path.replace(".obj", ".mtl"))
+
+            mtl_path = None
+            for candidate in mtl_path_candidates:
+                if os.path.exists(candidate):
+                    mtl_path = candidate
+                    break
+            
+            # if albedo_path is not provided, try retrieve it from mtl
+            metallic_path = None
+            roughness_path = None
+            if mtl_path is not None and albedo_path is None:
+                with open(mtl_path, "r") as f:
+                    lines = f.readlines()
+
+                for line in lines:
+                    split_line = line.split()
+                    # empty line
+                    if len(split_line) == 0:
+                        continue
+                    prefix = split_line[0]
+                    
+                    if "map_Kd" in prefix:
+                        # assume relative path!
+                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
+                        print(f"[load_obj] use texture from: {albedo_path}")
+                    elif "map_Pm" in prefix:
+                        metallic_path = os.path.join(os.path.dirname(path), split_line[1])
+                    elif "map_Pr" in prefix:
+                        roughness_path = os.path.join(os.path.dirname(path), split_line[1])
+                    
+            # albedo_path still not found, or the path doesn't exist
+            if albedo_path is None or not os.path.exists(albedo_path):
+                # init an empty texture
+                print(f"[load_obj] init empty albedo!")
+                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
+                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color
+            else:
+                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
+                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
+                albedo = albedo.astype(np.float32) / 255
+                print(f"[load_obj] load texture: {albedo.shape}")
+            
+            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
+            
+            # try to load metallic and roughness
+            if metallic_path is not None and roughness_path is not None:
+                print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
+                metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
+                metallic = metallic.astype(np.float32) / 255
+                roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
+                roughness = roughness.astype(np.float32) / 255
+                metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
+
+                mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
+
+        return mesh
+
+    @classmethod
+    def load_trimesh(cls, path, device=None):
+        """load a mesh using ``trimesh.load()``.
+
+        Can load various formats like ``glb`` and serves as a fallback.
+
+        Note:
+            We will try to merge all meshes if the glb contains more than one,
+            but **this may cause the texture to be lost**, since we only support one texture image!
+
+        Args:
+            path (str): path to the mesh file.
+            device (torch.device, optional): torch device. Defaults to None.
+
+        Returns:
+            Mesh: the loaded Mesh object.
+        """
+        mesh = cls()
+
+        # device
+        if device is None:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        mesh.device = device
+
+        # use trimesh to load ply/glb
+        _data = trimesh.load(path)
+        if isinstance(_data, trimesh.Scene):
+            if len(_data.geometry) == 1:
+                _mesh = list(_data.geometry.values())[0]
+            else:
+                print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
+                _concat = []
+                # loop the scene graph and apply transform to each mesh
+                scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}}
+                for k, v in scene_graph.items():
+                    name = v['geometry']
+                    if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
+                        transform = v['transform']
+                        _concat.append(_data.geometry[name].apply_transform(transform))
+                _mesh = trimesh.util.concatenate(_concat)
+        else:
+            _mesh = _data
+        
+        if _mesh.visual.kind == 'vertex':
+            vertex_colors = _mesh.visual.vertex_colors
+            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
+            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
+            print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
+        elif _mesh.visual.kind == 'texture':
+            _material = _mesh.visual.material
+            if isinstance(_material, trimesh.visual.material.PBRMaterial):
+                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
+                # load metallicRoughness if present
+                if _material.metallicRoughnessTexture is not None:
+                    metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
+                    mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
+            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
+                texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
+            else:
+                raise NotImplementedError(f"material type {type(_material)} not supported!")
+            mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
+            print(f"[load_trimesh] load texture: {texture.shape}")
+        else:
+            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
+            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
+            print(f"[load_trimesh] failed to load texture.")
+
+        vertices = _mesh.vertices
+
+        try:
+            texcoords = _mesh.visual.uv
+            texcoords[:, 1] = 1 - texcoords[:, 1]
+        except Exception as e:
+            texcoords = None
+
+        try:
+            normals = _mesh.vertex_normals
+        except Exception as e:
+            normals = None
+
+        # trimesh only supports vertex uv...
+        faces = tfaces = nfaces = _mesh.faces
+
+        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
+        mesh.vt = (
+            torch.tensor(texcoords, dtype=torch.float32, device=device)
+            if texcoords is not None
+            else None
+        )
+        mesh.vn = (
+            torch.tensor(normals, dtype=torch.float32, device=device)
+            if normals is not None
+            else None
+        )
+
+        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
+        mesh.ft = (
+            torch.tensor(tfaces, dtype=torch.int32, device=device)
+            if texcoords is not None
+            else None
+        )
+        mesh.fn = (
+            torch.tensor(nfaces, dtype=torch.int32, device=device)
+            if normals is not None
+            else None
+        )
+
+        return mesh
+
+    # sample surface (using trimesh)
+    def sample_surface(self, count: int):
+        """sample points on the surface of the mesh.
+
+        Args:
+            count (int): number of points to sample.
+
+        Returns:
+            torch.Tensor: the sampled points, float [count, 3].
+        """
+        _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
+        points, face_idx = trimesh.sample.sample_surface(_mesh, count)
+        points = torch.from_numpy(points).float().to(self.device)
+        return points
+
+    # aabb
+    def aabb(self):
+        """get the axis-aligned bounding box of the mesh.
+
+        Returns:
+            Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
+        """
+        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
+
+    # unit size
+    @torch.no_grad()
+    def auto_size(self, bound=0.9):
+        """auto resize the mesh.
+
+        Args:
+            bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
+        """
+        vmin, vmax = self.aabb()
+        self.ori_center = (vmax + vmin) / 2
+        self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
+        self.v = (self.v - self.ori_center) * self.ori_scale
+
+    def auto_normal(self):
+        """auto calculate the vertex normals.
+        """
+        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
+        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
+
+        face_normals = torch.cross(v1 - v0, v2 - v0)
+
+        # Splat face normals to vertices
+        vn = torch.zeros_like(self.v)
+        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
+        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
+        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
+
+        # Normalize, replace zero (degenerated) normals with some default value
+        vn = torch.where(
+            dot(vn, vn) > 1e-20,
+            vn,
+            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
+        )
+        vn = safe_normalize(vn)
+
+        self.vn = vn
+        self.fn = self.f
+
+    def auto_uv(self, cache_path=None, vmap=True):
+        """auto calculate the uv coordinates.
+
+        Args:
+            cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None.
+            vmap (bool, optional): remap vertices based on uv coordinates, so each v corresponds to a unique vt (necessary for formats like gltf).
+                Usually this will duplicate the vertices on the edge of the uv atlas. Defaults to True.
+        """
+        # try to load cache
+        if cache_path is not None:
+            cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
+        if cache_path is not None and os.path.exists(cache_path):
+            data = np.load(cache_path)
+            vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
+        else:
+            import xatlas
+
+            v_np = self.v.detach().cpu().numpy()
+            f_np = self.f.detach().int().cpu().numpy()
+            atlas = xatlas.Atlas()
+            atlas.add_mesh(v_np, f_np)
+            chart_options = xatlas.ChartOptions()
+            # chart_options.max_iterations = 4
+            atlas.generate(chart_options=chart_options)
+            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]
+
+            # save to cache
+            if cache_path is not None:
+                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)
+        
+        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
+        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
+        self.vt = vt
+        self.ft = ft
+
+        if vmap:
+            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
+            self.align_v_to_vt(vmapping)
+    
+    def align_v_to_vt(self, vmapping=None):
+        """ remap v/f and vn/fn to vt/ft.
+
+        Args:
+            vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
+        """
+        if vmapping is None:
+            ft = self.ft.view(-1).long()
+            f = self.f.view(-1).long()
+            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
+            vmapping[ft] = f # scatter, randomly choose one if index is not unique
+
+        self.v = self.v[vmapping]
+        self.f = self.ft
+        
+        if self.vn is not None:
+            self.vn = self.vn[vmapping]
+            self.fn = self.ft
+
+    def to(self, device):
+        """move all tensor attributes to device.
+
+        Args:
+            device (torch.device): target device.
+
+        Returns:
+            Mesh: self.
+        """
+        self.device = device
+        for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
+            tensor = getattr(self, name)
+            if tensor is not None:
+                setattr(self, name, tensor.to(device))
+        return self
+    
+    def write(self, path):
+        """write the mesh to a path.
+
+        Args:
+            path (str): path to write, supports ply, obj and glb.
+        """
+        if path.endswith(".ply"):
+            self.write_ply(path)
+        elif path.endswith(".obj"):
+            self.write_obj(path)
+        elif path.endswith(".glb") or path.endswith(".gltf"):
+            self.write_glb(path)
+        else:
+            raise NotImplementedError(f"format {path} not supported!")
+    
+    def write_ply(self, path):
+        """write the mesh in ply format. Only for geometry!
+
+        Args:
+            path (str): path to write.
+        """
+
+        if self.albedo is not None:
+            print(f'[WARN] ply format does not support exporting texture, will ignore!')
+
+        v_np = self.v.detach().cpu().numpy()
+        f_np = self.f.detach().cpu().numpy()
+
+        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
+        _mesh.export(path)
+
+
+    def write_glb(self, path):
+        """write the mesh in glb/gltf format.
+          This will create a scene with a single mesh.
+
+        Args:
+            path (str): path to write.
+        """
+
+        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
+        if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
+            self.align_v_to_vt()
+
+        import pygltflib
+
+        f_np = self.f.detach().cpu().numpy().astype(np.uint32)
+        f_np_blob = f_np.flatten().tobytes()
+
+        v_np = self.v.detach().cpu().numpy().astype(np.float32)
+        v_np_blob = v_np.tobytes()
+
+        blob = f_np_blob + v_np_blob
+        byteOffset = len(blob)
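+        # the single binary blob is laid out as [face indices | positions | (texcoords) | (albedo png) | (metallicRoughness png)];
+        # byteOffset tracks where the next bufferView starts as the optional sections are appended below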
+
+        # base mesh
+        gltf = pygltflib.GLTF2(
+            scene=0,
+            scenes=[pygltflib.Scene(nodes=[0])],
+            nodes=[pygltflib.Node(mesh=0)],
+            meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
+                # indices to accessors (0 is triangles)
+                attributes=pygltflib.Attributes(
+                    POSITION=1,
+                ),
+                indices=0,
+            )])],
+            buffers=[
+                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
+            ],
+            # buffer view (based on dtype)
+            bufferViews=[
+                # triangles; as flatten (element) array
+                pygltflib.BufferView(
+                    buffer=0,
+                    byteLength=len(f_np_blob),
+                    target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)
+                ),
+                # positions; as vec3 array
+                pygltflib.BufferView(
+                    buffer=0,
+                    byteOffset=len(f_np_blob),
+                    byteLength=len(v_np_blob),
+                    byteStride=12, # vec3
+                    target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)
+                ),
+            ],
+            accessors=[
+                # 0 = triangles
+                pygltflib.Accessor(
+                    bufferView=0,
+                    componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)
+                    count=f_np.size,
+                    type=pygltflib.SCALAR,
+                    max=[int(f_np.max())],
+                    min=[int(f_np.min())],
+                ),
+                # 1 = positions
+                pygltflib.Accessor(
+                    bufferView=1,
+                    componentType=pygltflib.FLOAT, # GL_FLOAT (5126)
+                    count=len(v_np),
+                    type=pygltflib.VEC3,
+                    max=v_np.max(axis=0).tolist(),
+                    min=v_np.min(axis=0).tolist(),
+                ),
+            ],
+        )
+
+        # append texture info
+        if self.vt is not None:
+
+            vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
+            vt_np_blob = vt_np.tobytes()
+
+            albedo = self.albedo.detach().cpu().numpy()
+            albedo = (albedo * 255).astype(np.uint8)
+            albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
+            albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()
+
+            # update primitive
+            gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
+            gltf.meshes[0].primitives[0].material = 0
+
+            # update materials
+            gltf.materials.append(pygltflib.Material(
+                pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
+                    baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
+                    metallicFactor=0.0,
+                    roughnessFactor=1.0,
+                ),
+                alphaMode=pygltflib.OPAQUE,
+                alphaCutoff=None,
+                doubleSided=True,
+            ))
+
+            gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
+            gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
+            gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))
+
+            # update buffers
+            gltf.bufferViews.append(
+                # index = 2, texcoords; as vec2 array
+                pygltflib.BufferView(
+                    buffer=0,
+                    byteOffset=byteOffset,
+                    byteLength=len(vt_np_blob),
+                    byteStride=8, # vec2
+                    target=pygltflib.ARRAY_BUFFER,
+                )
+            )
+
+            gltf.accessors.append(
+                # 2 = texcoords
+                pygltflib.Accessor(
+                    bufferView=2,
+                    componentType=pygltflib.FLOAT,
+                    count=len(vt_np),
+                    type=pygltflib.VEC2,
+                    max=vt_np.max(axis=0).tolist(),
+                    min=vt_np.min(axis=0).tolist(),
+                )
+            )
+
+            blob += vt_np_blob 
+            byteOffset += len(vt_np_blob)
+
+            gltf.bufferViews.append(
+                # index = 3, albedo texture; as none target
+                pygltflib.BufferView(
+                    buffer=0,
+                    byteOffset=byteOffset,
+                    byteLength=len(albedo_blob),
+                )
+            )
+
+            blob += albedo_blob
+            byteOffset += len(albedo_blob)
+
+            gltf.buffers[0].byteLength = byteOffset
+
+            # append metallic-roughness texture
+            if self.metallicRoughness is not None:
+                metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
+                metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
+                metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
+                metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
+
+                # update texture definition
+                gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
+                gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
+                gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
+
+                gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
+                gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
+                gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
+
+                # update buffers
+                gltf.bufferViews.append(
+                    # index = 4, metallicRoughness texture; as none target
+                    pygltflib.BufferView(
+                        buffer=0,
+                        byteOffset=byteOffset,
+                        byteLength=len(metallicRoughness_blob),
+                    )
+                )
+
+                blob += metallicRoughness_blob
+                byteOffset += len(metallicRoughness_blob)
+
+                gltf.buffers[0].byteLength = byteOffset
+
+            
+        # set actual data
+        gltf.set_binary_blob(blob)
+
+        # glb = b"".join(gltf.save_to_bytes())
+        gltf.save(path)
+
+
+    def write_obj(self, path):
+        """write the mesh in obj format. Will also write the texture and mtl files.
+
+        Args:
+            path (str): path to write.
+        """
+
+        mtl_path = path.replace(".obj", ".mtl")
+        albedo_path = path.replace(".obj", "_albedo.png")
+        metallic_path = path.replace(".obj", "_metallic.png")
+        roughness_path = path.replace(".obj", "_roughness.png")
+
+        v_np = self.v.detach().cpu().numpy()
+        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
+        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
+        f_np = self.f.detach().cpu().numpy()
+        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
+        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
+
+        with open(path, "w") as fp:
+            fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
+
+            for v in v_np:
+                fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
+
+            if vt_np is not None:
+                for v in vt_np:
+                    fp.write(f"vt {v[0]} {1 - v[1]} \n")
+
+            if vn_np is not None:
+                for v in vn_np:
+                    fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
+
+            fp.write(f"usemtl defaultMat \n")
+            for i in range(len(f_np)):
+                fp.write(
+                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
+                             {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
+                             {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
+                )
+
+        with open(mtl_path, "w") as fp:
+            fp.write(f"newmtl defaultMat \n")
+            fp.write(f"Ka 1 1 1 \n")
+            fp.write(f"Kd 1 1 1 \n")
+            fp.write(f"Ks 0 0 0 \n")
+            fp.write(f"Tr 1 \n")
+            fp.write(f"illum 1 \n")
+            fp.write(f"Ns 0 \n")
+            if self.albedo is not None:
+                fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
+            if self.metallicRoughness is not None:
+                # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
+                fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
+                fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
+
+        if self.albedo is not None:
+            albedo = self.albedo.detach().cpu().numpy()
+            albedo = (albedo * 255).astype(np.uint8)
+            cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
+        
+        if self.metallicRoughness is not None:
+            metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
+            metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
+            cv2.imwrite(metallic_path, metallicRoughness[..., 2])
+            cv2.imwrite(roughness_path, metallicRoughness[..., 1])
+
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b339e3ea9dac5482a6daed63b069e4d1eda000a8
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1 @@
+from model.crm.model import CRM
\ No newline at end of file
diff --git a/model/archs/__init__.py b/model/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/archs/decoders/__init__.py b/model/archs/decoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/model/archs/decoders/__init__.py
@@ -0,0 +1 @@
+
diff --git a/model/archs/decoders/shape_texture_net.py b/model/archs/decoders/shape_texture_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..706d95cb82e9295390a848f4534119fa735f52a7
--- /dev/null
+++ b/model/archs/decoders/shape_texture_net.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TetTexNet(nn.Module):
+    def __init__(self, plane_reso=64, padding=0.1, fea_concat=True):
+        super().__init__()
+        # self.c_dim = c_dim
+        self.plane_reso = plane_reso
+        self.padding = padding
+        self.fea_concat = fea_concat
+
+    def forward(self, rolled_out_feature, query):
+        # rolled_out_feature: rolled-out triplane feature
+        # query: queried xyz coordinates (should be scaled consistently with the point cloud)
+
+        plane_reso = self.plane_reso
+
+        triplane_feature = dict()
+        triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso]
+        triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso]
+        triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:]
+
+        query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy')
+        query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz')
+        query_feature_zx = self.sample_plane_feature(query, triplane_feature['zx'], 'zx')
+
+        if self.fea_concat:
+            query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1)
+        else:
+            query_feature = query_feature_xy + query_feature_yz + query_feature_zx
+
+        output = query_feature.permute(0, 2, 1)
+
+        return output
+
+    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
+    def sample_plane_feature(self, query, plane_feature, plane):
+        # CYF note:
+        # for pretraining, queries are uniformly sampled positions within [-scale, scale]
+        # for training, queries are essentially tetrahedral grid vertices, which are
+        # also within [-scale, scale] in the current version!
+        # xy range [-scale, scale]
+        if plane == 'xy':
+            xy = query[:, :, [0, 1]]
+        elif plane == 'yz':
+            xy = query[:, :, [1, 2]]
+        elif plane == 'zx':
+            xy = query[:, :, [2, 0]]
+        else:
+            raise ValueError("Error! Invalid plane type!")
+
+        xy = xy[:, :, None].float()
+        # it does not seem necessary to rescale the grid, because according to
+        # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html,
+        # grid_sample expects sampling locations normalized to the feature map's spatial dimensions,
+        # which matches the [-scale, scale] range produced by the encoder's call to coordinate2index()
+        vgrid = 1.0 * xy
+        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
+
+        return sampled_feat
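+
+
+if __name__ == "__main__":
+    # A minimal shape-check sketch; the feature/query sizes below are illustrative placeholders,
+    # not values taken from the CRM configs.
+    net = TetTexNet(plane_reso=64, fea_concat=True)
+    rolled_out = torch.randn(2, 32, 64, 64 * 3)  # [B, C, reso, 3 * reso] rolled-out triplane feature
+    query = torch.rand(2, 100, 3) * 2 - 1        # [B, N, 3] query coordinates in [-1, 1]
+    out = net(rolled_out, query)
+    print(out.shape)  # expected: torch.Size([2, 100, 96]), since the three plane features are concatenated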
diff --git a/model/archs/mlp_head.py b/model/archs/mlp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b87ff5b380bdd9b6e2f73cde42862080562eec3
--- /dev/null
+++ b/model/archs/mlp_head.py
@@ -0,0 +1,40 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SdfMlp(nn.Module):
+    def __init__(self, input_dim, hidden_dim=512, bias=True):
+        super().__init__()
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+
+        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
+        self.fc3 = nn.Linear(hidden_dim, 4, bias=bias)
+
+
+    def forward(self, input):
+        x = F.relu(self.fc1(input))
+        x = F.relu(self.fc2(x))
+        out = self.fc3(x)
+        return out
+
+
+class RgbMlp(nn.Module):
+    def __init__(self, input_dim, hidden_dim=512, bias=True):
+        super().__init__()
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+
+        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias)
+        self.fc3 = nn.Linear(hidden_dim, 3, bias=bias)
+
+    def forward(self, input):
+        x = F.relu(self.fc1(input))
+        x = F.relu(self.fc2(x))
+        out = self.fc3(x)
+
+        return out
+
+    
\ No newline at end of file
diff --git a/model/archs/unet.py b/model/archs/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43b72b45e1aba8ced1801c664f4afbcd7556682
--- /dev/null
+++ b/model/archs/unet.py
@@ -0,0 +1,53 @@
+'''
+Codes are from:
+https://github.com/jaxony/unet-pytorch/blob/master/model.py
+'''
+
+import torch
+import torch.nn as nn
+from diffusers import UNet2DModel
+import einops
+
+
+class UNetPP(nn.Module):
+    '''
+        Wrapper for UNet in diffusers
+    '''
+    def __init__(self, in_channels):
+        super(UNetPP, self).__init__()
+        self.in_channels = in_channels
+        self.unet = UNet2DModel(
+                sample_size=[256, 256*3],
+                in_channels=in_channels,
+                out_channels=32,
+                layers_per_block=2,
+                block_out_channels=(64, 128, 128, 128*2, 128*2, 128*4, 128*4),
+                down_block_types=(
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "DownBlock2D",
+                    "AttnDownBlock2D",
+                    "AttnDownBlock2D",
+                    "AttnDownBlock2D",
+                    "DownBlock2D",
+                ),
+                up_block_types=(
+                    "UpBlock2D",
+                    "AttnUpBlock2D",
+                    "AttnUpBlock2D",
+                    "AttnUpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                    "UpBlock2D",
+                ),
+            )
+             
+        self.unet.enable_xformers_memory_efficient_attention()    
+        if in_channels > 12:
+            self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3]))
+
+    def forward(self, x, t=256):
+        if x.shape[1] < self.in_channels:
+            # pad the missing channels with the learned constant plane (only defined when in_channels > 12)
+            learned_plane = einops.repeat(self.learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device)
+            x = torch.cat([x, learned_plane], dim=1)
+        return self.unet(x, t).sample
+
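+
+if __name__ == "__main__":
+    # Rough usage sketch (assumes a CUDA device and xformers, since the wrapper enables
+    # memory-efficient attention); the channel count is an illustrative guess, not a config value.
+    device = torch.device("cuda")
+    unet = UNetPP(in_channels=20).to(device)
+    x = torch.randn(1, 12, 256, 256 * 3, device=device)  # fewer channels than in_channels, so the learned plane is appended
+    with torch.no_grad():
+        out = unet(x, t=256)
+    print(out.shape)  # expected: torch.Size([1, 32, 256, 768])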
diff --git a/model/crm/model.py b/model/crm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a309283aeec7d976ed5c5f60de36c1710f51c4
--- /dev/null
+++ b/model/crm/model.py
@@ -0,0 +1,213 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+import numpy as np
+
+
+from pathlib import Path
+import cv2
+import trimesh
+import nvdiffrast.torch as dr
+
+from model.archs.decoders.shape_texture_net import TetTexNet
+from model.archs.unet import UNetPP
+from util.renderer import Renderer
+from model.archs.mlp_head import SdfMlp, RgbMlp
+import xatlas
+
+
+class Dummy:
+    pass
+
+class CRM(nn.Module):
+    def __init__(self, specs):
+        super(CRM, self).__init__()
+
+        self.specs = specs
+        # configs
+        input_specs = specs["Input"]
+        self.input = Dummy()
+        self.input.scale = input_specs['scale']
+        self.input.resolution = input_specs['resolution']
+        self.tet_grid_size = input_specs['tet_grid_size']
+        self.camera_angle_num = input_specs['camera_angle_num']
+
+        self.arch = Dummy()
+        self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"]
+        self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"]
+
+        self.dec = Dummy()
+        self.dec.c_dim = specs["DecoderSpecs"]["c_dim"]
+        self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"]
+
+        self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex"
+
+        self.unet2 = UNetPP(in_channels=self.dec.c_dim)
+
+        mlp_chnl_s = 3 if self.arch.fea_concat else 1  # 3 for queried triplane feature concatenation
+        self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat)
+
+        if self.geo_type == "flex":
+            self.weightMlp = nn.Sequential(
+                            nn.Linear(mlp_chnl_s * 32 * 8, 512),
+                            nn.SiLU(),
+                            nn.Linear(512, 21))         
+            
+        self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) 
+        self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias)
+        self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num,
+                                 scale=self.input.scale, geo_type = self.geo_type)
+
+
+        self.spob = True if specs['Pretrain']['mode'] is None else False  # whether to add sphere
+        self.radius = specs['Pretrain']['radius']  # used when spob
+
+        self.denoising = True
+        from diffusers import DDIMScheduler
+        self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
+
+    def decode(self, data, triplane_feature2):
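+        # query the triplane features at the iso-surface grid vertices, predict per-vertex SDF and
+        # deformation with the MLP heads, and let the renderer extract the surface mesh
+        # (with FlexiCubes weights when geo_type == "flex")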
+        if self.geo_type == "flex":
+            tet_verts = self.renderer.flexicubes.verts.unsqueeze(0)
+            tet_indices = self.renderer.flexicubes.indices
+
+        dec_verts = self.decoder(triplane_feature2, tet_verts)
+        out = self.sdfMlp(dec_verts)
+
+        weight = None
+        if self.geo_type == "flex":
+            grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1)
+            grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1])
+            weight = self.weightMlp(grid_feat)
+            weight = weight * 0.1
+
+        pred_sdf, deformation = out[..., 0], out[..., 1:]
+        if self.spob:
+            pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1))
+
+        _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight)
+        return verts[0].unsqueeze(0), faces[0].int()
+
+    def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None):
+        verts = data['verts']
+        faces = data['faces']
+
+        dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0))
+        colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy()
+        # Expect predicted colors value range from [-1, 1]
+        colors = (colors * 0.5 + 0.5).clip(0, 1)
+
+        verts = verts.squeeze().cpu().numpy()
+        faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy()
+
+        # export the final mesh
+        with torch.no_grad():
+            mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault...
+            mesh.export(out_dir / f'{ind}.obj')
+
+    def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None):
+
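+        # Texture-baking outline: unwrap the mesh with xatlas, rasterize the UV triangles with
+        # nvdiffrast, interpolate 3D surface positions per texel, query the color MLP at those
+        # positions, then dilate the texture to fill seam pixels before writing obj/mtl/png.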
+        mesh_v = data['verts'].squeeze().cpu().numpy()
+        mesh_pos_idx = data['faces'].squeeze().cpu().numpy()
+
+        def interpolate(attr, rast, attr_idx, rast_db=None):
+            return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db,
+                                  diff_attrs=None if rast_db is None else 'all')
+
+        vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx)
+
+        mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device)
+        mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device)
+
+        # Convert to tensors
+        indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
+
+        uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device)
+        mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device)
+        # map the uv coordinates into clip space for rasterization
+        uv_clip = uvs[None, ...] * 2.0 - 1.0
+
+        # pad to four component coordinate
+        uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1)
+
+        # rasterize
+        rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res)
+
+        # Interpolate world space position
+        gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int())
+        mask = rast[..., 3:4] > 0
+
+        # return uvs, mesh_tex_idx, gb_pos, mask
+        gb_pos_unsqz = gb_pos.view(-1, 3)
+        mask_unsqz = mask.view(-1)
+        tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1
+
+        gb_mask_pos = gb_pos_unsqz[mask_unsqz]
+
+        gb_mask_pos = gb_mask_pos[None, ]
+
+        with torch.no_grad():
+
+            dec_verts = self.decoder(tri_fea_2, gb_mask_pos)
+            colors = self.rgbMlp(dec_verts).squeeze()
+
+        # Expect predicted colors value range from [-1, 1]
+        lo, hi = (-1, 1)
+        colors = (colors - lo) * (255 / (hi - lo))
+        colors = colors.clip(0, 255)
+
+        tex_unsqz[mask_unsqz] = colors
+
+        tex = tex_unsqz.view(res + (3,))
+
+        verts = mesh_v.squeeze().cpu().numpy()
+        faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy()
+        # faces = mesh_pos_idx
+        # faces = faces.detach().cpu().numpy()
+        # faces = faces[..., [2, 1, 0]]
+        indices = indices[..., [2, 1, 0]]
+
+        # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs)
+        matname = f'{out_dir}.mtl'
+        # matname = f'{out_dir}/{ind}.mtl'
+        fid = open(matname, 'w')
+        fid.write('newmtl material_0\n')
+        fid.write('Kd 1 1 1\n')
+        fid.write('Ka 1 1 1\n')
+        # fid.write('Ks 0 0 0\n')
+        fid.write('Ks 0.4 0.4 0.4\n')
+        fid.write('Ns 10\n')
+        fid.write('illum 2\n')
+        fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n')
+        fid.close()
+
+        fid = open(f'{out_dir}.obj', 'w')
+        # fid = open(f'{out_dir}/{ind}.obj', 'w')
+        fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1])
+
+        for pidx, p in enumerate(verts):
+            pp = p
+            fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1]))
+
+        for pidx, p in enumerate(uvs):
+            pp = p
+            fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
+
+        fid.write('usemtl material_0\n')
+        for i, f in enumerate(faces):
+            f1 = f + 1
+            f2 = indices[i] + 1
+            fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
+        fid.close()
+
+        img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32)
+        mask = np.sum(img.astype(float), axis=-1, keepdims=True)
+        mask = (mask <= 3.0).astype(float)
+        kernel = np.ones((3, 3), 'uint8')
+        dilate_img = cv2.dilate(img, kernel, iterations=1)
+        img = img * (1 - mask) + dilate_img * mask
+        img = img.clip(0, 255).astype(np.uint8)
+
+        cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]])
+        # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]])
diff --git a/pipelines.py b/pipelines.py
new file mode 100644
index 0000000000000000000000000000000000000000..370eff5954f1850b18f626c487cbc55ad8ed98c2
--- /dev/null
+++ b/pipelines.py
@@ -0,0 +1,170 @@
+import torch
+from libs.base_utils import do_resize_content
+from imagedream.ldm.util import (
+    instantiate_from_config,
+    get_obj_from_str,
+)
+from omegaconf import OmegaConf
+from PIL import Image
+import numpy as np
+
+
+class TwoStagePipeline(object):
+    def __init__(
+        self,
+        stage1_model_config,
+        stage2_model_config,
+        stage1_sampler_config,
+        stage2_sampler_config,
+        device="cuda",
+        dtype=torch.float16,
+        resize_rate=1,
+    ) -> None:
+        """
+        only for the two-stage generation process.
+        - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
+        - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final images, based on the stage2-test config
+        """
+        self.resize_rate = resize_rate
+
+        self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
+        self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
+        self.stage1_model = self.stage1_model.to(device).to(dtype)
+
+        self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
+        sd = torch.load(stage2_model_config.resume, map_location="cpu")
+        self.stage2_model.load_state_dict(sd, strict=False)
+        self.stage2_model = self.stage2_model.to(device).to(dtype)
+
+        self.stage1_model.device = device
+        self.stage2_model.device = device
+        self.device = device
+        self.dtype = dtype
+        self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
+            self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
+        )
+        self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
+            self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
+        )
+
+    def stage1_sample(
+        self,
+        pixel_img,
+        prompt="3D assets",
+        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated,  obscure, unnatural colors, poor lighting, dull, and unclear.",
+        step=50,
+        scale=5,
+        ddim_eta=0.0,
+    ):
+        if isinstance(pixel_img, str):
+            pixel_img = Image.open(pixel_img)
+
+        if isinstance(pixel_img, Image.Image):
+            if pixel_img.mode == "RGBA":
+                background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
+                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
+            else:
+                pixel_img = pixel_img.convert("RGB")
+        else:
+            raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
+        uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
+        stage1_images = self.stage1_sampler.i2i(
+            self.stage1_sampler.model,
+            self.stage1_sampler.size,
+            prompt,
+            uc=uc,
+            sampler=self.stage1_sampler.sampler,
+            ip=pixel_img,
+            step=step,
+            scale=scale,
+            batch_size=self.stage1_sampler.batch_size,
+            ddim_eta=ddim_eta,
+            dtype=self.stage1_sampler.dtype,
+            device=self.stage1_sampler.device,
+            camera=self.stage1_sampler.camera,
+            num_frames=self.stage1_sampler.num_frames,
+            pixel_control=(self.stage1_sampler.mode == "pixel"),
+            transform=self.stage1_sampler.image_transform,
+            offset_noise=self.stage1_sampler.offset_noise,
+        )
+
+        stage1_images = [Image.fromarray(img) for img in stage1_images]
+        stage1_images.pop(self.stage1_sampler.ref_position)
+        return stage1_images
+
+    def stage2_sample(self, pixel_img, stage1_images):
+        if isinstance(pixel_img, str):
+            pixel_img = Image.open(pixel_img)
+
+        if isinstance(pixel_img, Image.Image):
+            if pixel_img.mode == "RGBA":
+                background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
+                pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
+            else:
+                pixel_img = pixel_img.convert("RGB")
+        else:
+            raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
+        stage2_images = self.stage2_sampler.i2iStage2(
+            self.stage2_sampler.model,
+            self.stage2_sampler.size,
+            "3D assets",
+            self.stage2_sampler.uc,
+            self.stage2_sampler.sampler,
+            pixel_images=stage1_images,
+            ip=pixel_img,
+            step=50,
+            scale=5,
+            batch_size=self.stage2_sampler.batch_size,
+            ddim_eta=0.0,
+            dtype=self.stage2_sampler.dtype,
+            device=self.stage2_sampler.device,
+            camera=self.stage2_sampler.camera,
+            num_frames=self.stage2_sampler.num_frames,
+            pixel_control=(self.stage2_sampler.mode == "pixel"),
+            transform=self.stage2_sampler.image_transform,
+            offset_noise=self.stage2_sampler.offset_noise,
+        )
+        stage2_images = [Image.fromarray(img) for img in stage2_images]
+        return stage2_images
+
+    def set_seed(self, seed):
+        self.stage1_sampler.seed = seed
+        self.stage2_sampler.seed = seed
+
+    def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
+        pixel_img = do_resize_content(pixel_img, self.resize_rate)
+        stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
+        stage2_images = self.stage2_sample(pixel_img, stage1_images)
+
+        return {
+            "ref_img": pixel_img,
+            "stage1_images": stage1_images,
+            "stage2_images": stage2_images,
+        }
+
+
+if __name__ == "__main__":
+
+    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
+    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
+    stage2_sampler_config = stage2_config.sampler
+    stage1_sampler_config = stage1_config.sampler
+
+    stage1_model_config = stage1_config.models
+    stage2_model_config = stage2_config.models
+
+    pipeline = TwoStagePipeline(
+        stage1_model_config,
+        stage2_model_config,
+        stage1_sampler_config,
+        stage2_sampler_config,
+    )
+
+    img = Image.open("assets/astronaut.png")
+    rt_dict = pipeline(img)
+    stage1_images = rt_dict["stage1_images"]
+    stage2_images = rt_dict["stage2_images"]
+    np_imgs = np.concatenate(stage1_images, 1)
+    np_xyzs = np.concatenate(stage2_images, 1)
+    Image.fromarray(np_imgs).save("pixel_images.png")
+    Image.fromarray(np_xyzs).save("xyz_images.png")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ae91e22851faea6eeb50236a90403642a84d7fd1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+gradio
+huggingface-hub
+einops==0.7.0
+Pillow==10.1.0
+transformers==4.27.1
+open-clip-torch==2.7.0
+opencv-contrib-python-headless==4.9.0.80
+opencv-python-headless==4.9.0.80
+xformers
+omegaconf
+rembg
+nvdiffrast
+pygltflib
+kiui
\ No newline at end of file
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/util/flexicubes.py b/util/flexicubes.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a22d63cafc1f611e09b9a684041abb29ccbcaec
--- /dev/null
+++ b/util/flexicubes.py
@@ -0,0 +1,579 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+import torch
+from util.tables import *
+
+__all__ = [
+    'FlexiCubes'
+]
+
+
+class FlexiCubes:
+    """
+    This class implements the FlexiCubes method for extracting meshes from scalar fields. 
+    It maintains a series of lookup tables and indices to support the mesh extraction process. 
+    FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances 
+    the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting 
+    the surface representation through gradient-based optimization.
+
+    During instantiation, the class loads DMC tables from a file and transforms them into 
+    PyTorch tensors on the specified device.
+
+    Attributes:
+        device (str): Specifies the computational device (default is "cuda").
+        dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges 
+            associated with each dual vertex in 256 Marching Cubes (MC) configurations.
+        num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of 
+            the 256 MC configurations.
+        check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 
+            of the DMC configurations.
+        tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
+        quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles 
+            along one diagonal.
+        quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into 
+            two triangles along the other diagonal.
+        quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles 
+            during training by connecting all edges to their midpoints.
+        cube_corners (torch.Tensor): Defines the positions of a standard unit cube's 
+            eight corners in 3D space, ordered starting from the origin (0,0,0), 
+            moving along the x-axis, then y-axis, and finally z-axis. 
+            Used as a blueprint for generating a voxel grid.
+        cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used 
+            to retrieve the case id.
+        cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. 
+            Used to retrieve edge vertices in DMC.
+        edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with 
+            their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the 
+            first edge is oriented along the x-axis. 
+        dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges 
+            across four adjacent cubes to the shared faces of these cubes. For instance, 
+            dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along 
+            the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. 
+            This tensor is only utilized during isosurface tetrahedralization.
+        adj_pairs (torch.Tensor): 
+            A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
+        qef_reg_scale (float):
+            The scaling factor applied to the regularization loss to prevent issues with singularity 
+            when solving the QEF. This parameter is only used when a 'grad_func' is specified.
+        weight_scale (float):
+            The scale of weights in FlexiCubes. Should be between 0 and 1.
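+
+    Example:
+        A minimal sketch (assumes a CUDA device and a user-provided callable ``sdf_fn`` that maps
+        Nx3 positions to signed-distance values)::
+
+            fc = FlexiCubes(device="cuda")
+            verts, cubes = fc.construct_voxel_grid(64)   # grid vertices and per-cube corner indices
+            sdf = sdf_fn(verts)                          # per-vertex scalar field, shape (N,)
+            mesh_verts, mesh_faces, L_dev = fc(verts, sdf, cubes, 64)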
+    """
+
+    def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
+
+        self.device = device
+        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
+        self.num_vd_table = torch.tensor(num_vd_table,
+                                         dtype=torch.long, device=device, requires_grad=False)
+        self.check_table = torch.tensor(
+            check_table,
+            dtype=torch.long, device=device, requires_grad=False)
+
+        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
+        self.quad_split_train = torch.tensor(
+            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
+
+        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
+                                         1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
+        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
+        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
+                                       2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
+
+        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
+                                           dtype=torch.long, device=device)
+        self.dir_faces_table = torch.tensor([
+            [[5, 4], [3, 2], [4, 5], [2, 3]],
+            [[5, 4], [1, 0], [4, 5], [0, 1]],
+            [[3, 2], [1, 0], [2, 3], [0, 1]]
+        ], dtype=torch.long, device=device)
+        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
+        self.qef_reg_scale = qef_reg_scale
+        self.weight_scale = weight_scale
+
+    def construct_voxel_grid(self, res):
+        """
+        Generates a voxel grid based on the specified resolution.
+
+        Args:
+            res (int or list[int]): The resolution of the voxel grid. If an integer
+                is provided, it is used for all three dimensions. If a list or tuple 
+                of 3 integers is provided, they define the resolution for the x, 
+                y, and z dimensions respectively.
+
+        Returns:
+            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the 
+                cube corners (index into vertices) of the constructed voxel grid. 
+                The vertices are centered at the origin, with the length of each 
+                dimension in the grid being one.
+        """
+        base_cube_f = torch.arange(8).to(self.device)
+        if isinstance(res, int):
+            res = (res, res, res)
+        voxel_grid_template = torch.ones(res, device=self.device)
+
+        res = torch.tensor([res], dtype=torch.float, device=self.device)
+        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3
+        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
+        cubes = (base_cube_f.unsqueeze(0) +
+                 torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)
+
+        verts_rounded = torch.round(verts * 10**5) / (10**5)
+        verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
+        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)
+
+        return verts_unique - 0.5, cubes
+
+    def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,
+                 gamma_f=None, training=False, output_tetmesh=False, grad_func=None):
+        r"""
+        Main function for mesh extraction from scalar field using FlexiCubes. This function converts 
+        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, 
+        to triangle or tetrahedral meshes using a differentiable operation as described in 
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances 
+        mesh quality and geometric fidelity by adjusting the surface representation based on gradient 
+        optimization. The output surface is differentiable with respect to the input vertex positions, 
+        scalar field values, and weight parameters.
+
+        If you intend to extract a surface mesh from a fixed Signed Distance Field without optimizing
+        the parameters, it is suggested to provide "grad_func", which should return the surface
+        gradient at any given 3D position. When grad_func is provided, the process that determines the
+        dual vertex position adapts to solve a Quadratic Error Function (QEF), as described in the
+        `Manifold Dual Contouring`_ paper, and employs a smart splitting strategy.
+        Please note that this approach is non-differentiable.
+
+        For more details and example usage in optimization, refer to the 
+        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.
+
+        Args:
+            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
+            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values 
+                denote that the corresponding vertex resides inside the isosurface. This affects 
+                the directions of the extracted triangle faces and volume to be tetrahedralized.
+            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
+            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it 
+                is used for all three dimensions. If a list or tuple of 3 integers is provided, they 
+                specify the resolution for the x, y, and z dimensions respectively.
+            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual 
+                vertices positioning. Defaults to uniform value for all edges.
+            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual 
+                vertices positioning. Defaults to uniform value for all vertices.
+            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of 
+                quadrilaterals into triangles. Defaults to uniform value for all cubes.
+            training (bool, optional): If set to True, applies differentiable quad splitting for 
+                training. Defaults to False.
+            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, 
+                outputs a triangular mesh. Defaults to False.
+            grad_func (callable, optional): A function to compute the surface gradient at specified 
+                3D positions (input: Nx3 positions). The function should return gradients as an Nx3 
+                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.
+
+        Returns:
+            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
+                - Vertices for the extracted triangular/tetrahedral mesh.
+                - Faces for the extracted triangular/tetrahedral mesh.
+                - Regularizer L_dev, computed per dual vertex.
+
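+        Example:
+            A minimal sketch of non-differentiable extraction from a fixed SDF (``sdf_fn`` and
+            ``sdf_grad_fn`` are hypothetical user-supplied callables returning values and gradients
+            at Nx3 positions)::
+
+                verts, cubes = fc.construct_voxel_grid(res)
+                mesh_v, mesh_f, _ = fc(verts, sdf_fn(verts), cubes, res, grad_func=sdf_grad_fn)
+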
+        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
+            https://research.nvidia.com/labs/toronto-ai/flexicubes/
+        .. _Manifold Dual Contouring:
+            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
+        """
+
+        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
+        if surf_cubes.sum() == 0:
+            empty_verts = torch.zeros((0, 3), device=self.device)
+            empty_faces = torch.zeros((0, 4) if output_tetmesh else (0, 3), dtype=torch.long, device=self.device)
+            empty_reg = torch.zeros((0,), device=self.device)
+            return empty_verts, empty_faces, empty_reg
+        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)
+
+        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)
+
+        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)
+
+        vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
+            x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
+        vertices, faces, s_edges, edge_indices = self._triangulate(
+            s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)
+        if not output_tetmesh:
+            return vertices, faces, L_dev
+        else:
+            vertices, tets = self._tetrahedralize(
+                x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
+                surf_cubes, training)
+            return vertices, tets, L_dev
+
+    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
+        """
+        Regularizer L_dev as in Equation 8
+        """
+        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
+        mean_l2 = torch.zeros_like(vd[:, 0])
+        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
+        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
+        return mad
+
+    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
+        """
+        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
+        """
+        n_cubes = surf_cubes.shape[0]
+
+        if beta_fx12 is not None:
+            beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
+        else:
+            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
+
+        if alpha_fx8 is not None:
+            alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
+        else:
+            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
+
+        if gamma_f is not None:
+            gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
+        else:
+            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
+
+        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]
+
+    @torch.no_grad()
+    def _get_case_id(self, occ_fx8, surf_cubes, res):
+        """
+        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the 
+        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the 
+        supplementary material. It should be noted that this function assumes a regular grid.
+        """
+        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
+
+        problem_config = self.check_table.to(self.device)[case_ids]
+        to_check = problem_config[..., 0] == 1
+        problem_config = problem_config[to_check]
+        if not isinstance(res, (list, tuple)):
+            res = [res, res, res]
+
+        # 'problem_config' only contains configurations for surface cubes. Next, we construct a 3D array,
+        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
+        # This allows efficient checks on adjacent cubes.
+        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
+        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
+        vol_idx_problem = vol_idx[surf_cubes][to_check]
+        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
+        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
+
+        within_range = (
+            vol_idx_problem_adj[..., 0] >= 0) & (
+            vol_idx_problem_adj[..., 0] < res[0]) & (
+            vol_idx_problem_adj[..., 1] >= 0) & (
+            vol_idx_problem_adj[..., 1] < res[1]) & (
+            vol_idx_problem_adj[..., 2] >= 0) & (
+            vol_idx_problem_adj[..., 2] < res[2])
+
+        vol_idx_problem = vol_idx_problem[within_range]
+        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
+        problem_config = problem_config[within_range]
+        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
+                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
+        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
+        to_invert = (problem_config_adj[..., 0] == 1)
+        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
+        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
+        return case_ids
+
+    @torch.no_grad()
+    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
+        """
+        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge 
+        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge 
+        and marks the cube edges with this index.
+        """
+        occ_n = s_n < 0
+        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
+
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
+
+        surf_edges_mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
+        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
+        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
+        idx_map = mapping[_idx_map]
+        surf_edges = unique_edges[mask_edges]
+        return surf_edges, idx_map, counts, surf_edges_mask
+
+    @torch.no_grad()
+    def _identify_surf_cubes(self, s_n, cube_fx8):
+        """
+        Identifies grid cubes that intersect with the underlying surface by checking if the signs at 
+        all corners are not identical.
+        """
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        _occ_sum = torch.sum(occ_fx8, -1)
+        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
+        return surf_cubes, occ_fx8
+
+    def _linear_interp(self, edges_weight, edges_x):
+        """
+        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
+        """
+        edge_dim = edges_weight.dim() - 2
+        assert edges_weight.shape[edge_dim] == 2
+        edges_weight = torch.cat([
+            torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
+            -torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim),
+        ], edge_dim)
+        denominator = edges_weight.sum(edge_dim)
+        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
+        return ue
+
+    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
+        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
+        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
+        c_bx3 = c_bx3.reshape(-1, 3)
+        A = norm_bxnx3
+        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
+
+        A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
+        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
+        A = torch.cat([A, A_reg], 1)
+        B = torch.cat([B, B_reg], 1)
+        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
+        return dual_verts
+
+    def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):
+        """
+        Computes the location of dual vertices as described in Section 4.2
+        """
+        alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
+        surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
+        surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
+        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
+
+        idx_map = idx_map.reshape(-1, 12)
+        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
+        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
+
+        total_num_vd = 0
+        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
+        if grad_func is not None:
+            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
+            vd = []
+        for num in torch.unique(num_vd):
+            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
+            curr_num_vd = cur_cubes.sum() * num
+            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
+            curr_edge_group_to_vd = torch.arange(
+                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
+            total_num_vd += curr_num_vd
+            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
+                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
+
+            curr_mask = (curr_edge_group != -1)
+            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
+            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
+            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
+            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
+            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
+
+            if grad_func is not None:
+                with torch.no_grad():
+                    cube_e_verts_idx = idx_map[cur_cubes]
+                    curr_edge_group[~curr_mask] = 0
+
+                    verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
+                    verts_group_idx[verts_group_idx == -1] = 0
+                    verts_group_pos = torch.index_select(
+                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
+                    v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
+                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
+                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))
+
+                    normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
+                        -1, num.item(), 7,
+                        3)
+                    curr_mask = curr_mask.squeeze(2)
+                    vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
+                                                 verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))
+        edge_group = torch.cat(edge_group)
+        edge_group_to_vd = torch.cat(edge_group_to_vd)
+        edge_group_to_cube = torch.cat(edge_group_to_cube)
+        vd_num_edges = torch.cat(vd_num_edges)
+        vd_gamma = torch.cat(vd_gamma)
+
+        if grad_func is not None:
+            vd = torch.cat(vd)
+            L_dev = torch.zeros([1], device=self.device)
+        else:
+            vd = torch.zeros((total_num_vd, 3), device=self.device)
+            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
+
+            idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
+
+            x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
+            s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
+
+            zero_crossing_group = torch.index_select(
+                input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
+
+            alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
+                                             index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
+            ue_group = self._linear_interp(s_group * alpha_group, x_group)
+
+            beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
+                                      index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
+            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
+            vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
+            L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
+
+        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd
+
+        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
+                                                      12 + edge_group, src=v_idx[edge_group_to_vd])
+
+        return vd, L_dev, vd_gamma, vd_idx_map
+
+    def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):
+        """
+        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into 
+        triangles based on the gamma parameter, as described in Section 4.3.
+        """
+        with torch.no_grad():
+            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
+            group = idx_map.reshape(-1)[group_mask]
+            vd_idx = vd_idx_map[group_mask]
+            edge_indices, indices = torch.sort(group, stable=True)
+            quad_vd_idx = vd_idx[indices].reshape(-1, 4)
+
+            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
+            s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
+            flip_mask = s_edges[:, 0] > 0
+            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
+                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
+        if grad_func is not None:
+            # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
+            with torch.no_grad():
+                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
+                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
+                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
+                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
+        else:
+            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
+            gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
+                0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
+            gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
+                1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
+        if not training:
+            mask = (gamma_02 > gamma_13).squeeze(1)
+            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
+            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
+            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
+            faces = faces.reshape(-1, 3)
+        else:
+            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
+            vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
+                     torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
+            vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
+                     torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
+            weight_sum = (gamma_02 + gamma_13) + 1e-8
+            vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /
+                         weight_sum.unsqueeze(-1)).squeeze(1)
+            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
+            vd = torch.cat([vd, vd_center])
+            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
+            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
+        return vd, faces, s_edges, edge_indices
+
+    def _tetrahedralize(
+            self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
+            surf_cubes, training):
+        """
+        Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
+        """
+        occ_n = s_n < 0
+        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
+        occ_sum = torch.sum(occ_fx8, -1)
+
+        inside_verts = x_nx3[occ_n]
+        mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
+        mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
+        """ 
+        For each grid edge connecting two grid vertices with different
+        signs, we first form a four-sided pyramid by connecting one
+        of the grid vertices with four mesh vertices that correspond
+        to the grid edge and then subdivide the pyramid into two tetrahedra
+        """
+        inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
+            s_edges < 0]]
+        if not training:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
+        else:
+            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)
+
+        tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
+        """ 
+        For each grid edge connecting two grid vertices with the
+        same sign, the tetrahedron is formed by the two grid vertices
+        and two vertices in consecutive adjacent cells
+        """
+        inside_cubes = (occ_sum == 8)
+        inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
+        inside_cubes_center_idx = torch.arange(
+            inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]
+
+        surface_n_inside_cubes = surf_cubes | inside_cubes
+        edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
+                                            dtype=torch.long, device=x_nx3.device) * -1
+        surf_cubes = surf_cubes[surface_n_inside_cubes]
+        inside_cubes = inside_cubes[surface_n_inside_cubes]
+        edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
+        edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx
+
+        all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
+        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
+        unique_edges = unique_edges.long()
+        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
+        mask = mask_edges[_idx_map]
+        counts = counts[_idx_map]
+        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
+        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
+        idx_map = mapping[_idx_map]
+
+        group_mask = (counts == 4) & mask
+        group = idx_map.reshape(-1)[group_mask]
+        edge_indices, indices = torch.sort(group)
+        cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
+                                device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
+        edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
+            0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
+        # Identify the face shared by the adjacent cells.
+        cube_idx_4 = cube_idx[indices].reshape(-1, 4)
+        edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
+        shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
+        cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)
+        # Identify an edge of the face with different signs and
+        # select the mesh vertex corresponding to the identified edge.
+        case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
+        case_ids_expand[surf_cubes] = case_ids
+        cases = case_ids_expand[cube_idx_4x2]
+        quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
+        mask = (quad_edge == -1).sum(-1) == 0
+        inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
+        tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]
+
+        tets = torch.cat([tets_surface, tets_inside])
+        vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
+        return vertices, tets
diff --git a/util/flexicubes_geometry.py b/util/flexicubes_geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..055805f3ad1396cef768e02e2211b4b3b5601c82
--- /dev/null
+++ b/util/flexicubes_geometry.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+
+import torch
+from util.flexicubes import FlexiCubes # replace later
+# from dmtet import sdf_reg_loss_batch
+import torch.nn.functional as F
+
+def get_center_boundary_index(grid_res, device):
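+    # Returns the flat index of a single near-center grid vertex and the indices of all vertices
+    # in a two-layer shell at the grid boundary (stored by FlexiCubesGeometry for fixing the boundary SDF).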
+    v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device)
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True
+    center_indices = torch.nonzero(v.reshape(-1))
+
+    v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False
+    v[:2, ...] = True
+    v[-2:, ...] = True
+    v[:, :2, ...] = True
+    v[:, -2:, ...] = True
+    v[:, :, :2] = True
+    v[:, :, -2:] = True
+    boundary_indices = torch.nonzero(v.reshape(-1))
+    return center_indices, boundary_indices
+
+###############################################################################
+#  Geometry interface
+###############################################################################
+class FlexiCubesGeometry(object):
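+    """Thin wrapper around FlexiCubes that owns a voxel grid and, optionally, a renderer."""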
+    def __init__(
+            self, grid_res=64, scale=2.0, device='cuda', renderer=None,
+            render_type='neural_render', args=None):
+        super(FlexiCubesGeometry, self).__init__()
+        self.grid_res = grid_res
+        self.device = device
+        self.args = args
+        self.fc = FlexiCubes(device, weight_scale=0.5)
+        self.verts, self.indices = self.fc.construct_voxel_grid(grid_res)
+        if isinstance(scale, list):
+            self.verts[:, 0] = self.verts[:, 0] * scale[0]
+            self.verts[:, 1] = self.verts[:, 1] * scale[1]
+            self.verts[:, 2] = self.verts[:, 2] * scale[1]
+        else:
+            self.verts = self.verts * scale
+            
+        all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2)
+        self.all_edges = torch.unique(all_edges, dim=0)
+
+        # Parameters used to fix the boundary SDF
+        self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device)
+        self.renderer = renderer
+        self.render_type = render_type
+
+    def getAABB(self):
+        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
+
+    def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False):
+        if indices is None:
+            indices = self.indices
+
+        verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res,
+                                            beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20],
+                                            gamma_f=weight_n[:, 20], training=is_training
+                                            )
+        return verts, faces, v_reg_loss
+
+
+    def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False):
+        return_value = dict()
+        if self.render_type == 'neural_render':
+            tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh(
+                mesh_v_nx3.unsqueeze(dim=0),
+                mesh_f_fx3.int(),
+                camera_mv_bx4x4,
+                mesh_v_nx3.unsqueeze(dim=0),
+                resolution=resolution,
+                device=self.device,
+                hierarchical_mask=hierarchical_mask
+            )
+
+            return_value['tex_pos'] = tex_pos
+            return_value['mask'] = mask
+            return_value['hard_mask'] = hard_mask
+            return_value['rast'] = rast
+            return_value['v_pos_clip'] = v_pos_clip
+            return_value['mask_pyramid'] = mask_pyramid
+            return_value['depth'] = depth
+        else:
+            raise NotImplementedError
+
+        return return_value
+
+    def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256):
+        # Assumes a batch of meshes (each element may have a different mesh and geometry); for other shapes, the batch size is 1.
+        v_list = []
+        f_list = []
+        n_batch = v_deformed_bxnx3.shape[0]
+        all_render_output = []
+        for i_batch in range(n_batch):
+            verts_nx3, faces_fx3, _ = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch])
+            v_list.append(verts_nx3)
+            f_list.append(faces_fx3)
+            render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution)
+            all_render_output.append(render_output)
+
+        # Concatenate all render output
+        return_keys = all_render_output[0].keys()
+        return_value = dict()
+        for k in return_keys:
+            value = [v[k] for v in all_render_output]
+            return_value[k] = value
+            # We can do concatenation outside of the render
+        return return_value
diff --git a/util/renderer.py b/util/renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a37627cb7782777c21eb26bf8aa9909c0da9ca25
--- /dev/null
+++ b/util/renderer.py
@@ -0,0 +1,49 @@
+
+import torch
+import torch.nn as nn
+import nvdiffrast.torch as dr
+from util.flexicubes_geometry import FlexiCubesGeometry
+
+class Renderer(nn.Module):
+    def __init__(self, tet_grid_size, camera_angle_num, scale, geo_type):
+        super().__init__()
+
+        self.tet_grid_size = tet_grid_size
+        self.camera_angle_num = camera_angle_num
+        self.scale = scale
+        self.geo_type = geo_type
+        self.glctx = dr.RasterizeCudaContext()
+
+        if self.geo_type == "flex":
+            self.flexicubes = FlexiCubesGeometry(grid_res=self.tet_grid_size)
+
+    def forward(self, data, sdf, deform, verts, tets, training=False, weight=None):
+
+        results = {}
+
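+        # Bound raw vertex offsets with tanh and rescale them to a fraction of one grid cell
+        # (in world units) so deformed vertices stay near their original grid positions.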
+        deform = torch.tanh(deform) / self.tet_grid_size * self.scale / 0.95
+        if self.geo_type == "flex":
+            deform = deform * 0.5
+
+            v_deformed = verts + deform
+
+            verts_list = []
+            faces_list = []
+            reg_list = []
+            n_shape = verts.shape[0]
+            for i in range(n_shape):
+                verts_i, faces_i, reg_i = self.flexicubes.get_mesh(
+                    v_deformed[i], sdf[i].squeeze(dim=-1),
+                    with_uv=False, indices=tets, weight_n=weight[i], is_training=training)
+
+                verts_list.append(verts_i)
+                faces_list.append(faces_i)
+                reg_list.append(reg_i)       
+            verts = verts_list
+            faces = faces_list
+
+            flexicubes_surface_reg = torch.cat(reg_list).mean()
+            flexicubes_weight_reg = (weight ** 2).mean()
+            results["flex_surf_loss"] = flexicubes_surface_reg
+            results["flex_weight_loss"] = flexicubes_weight_reg
+
+        return results, verts, faces
\ No newline at end of file
diff --git a/util/tables.py b/util/tables.py
new file mode 100644
index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca
--- /dev/null
+++ b/util/tables.py
@@ -0,0 +1,791 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
+#
+# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto.  Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
+dmc_table = [
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],
+[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]
+]
+num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,
+2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,
+1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,
+1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,
+2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,
+3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,
+2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,
+1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,
+1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
+1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
+check_table = [
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 194],
+[1, -1, 0, 0, 193],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 164],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 161],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 152],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 145],
+[1, 0, 0, 1, 144],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 137],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 133],
+[1, 0, 1, 0, 132],
+[1, 1, 0, 0, 131],
+[1, 1, 0, 0, 130],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 100],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 98],
+[0, 0, 0, 0, 0],
+[1, 0, 0, 1, 96],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 88],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 82],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 74],
+[0, 0, 0, 0, 0],
+[1, 0, 1, 0, 72],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 70],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 67],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 65],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 56],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 52],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 44],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 1, 0, 0, 40],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 38],
+[1, 0, -1, 0, 37],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 33],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 28],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 26],
+[1, 0, 0, -1, 25],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, -1, 0, 0, 20],
+[0, 0, 0, 0, 0],
+[1, 0, -1, 0, 18],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 9],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[1, 0, 0, -1, 6],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0]
+]
+tet_table = [
+[-1, -1, -1, -1, -1, -1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, -1],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, -1],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, -1, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, -1, 2, 4, 4, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, 5, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, -1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[-1, 1, 1, 4, 4, 1],
+[0, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[8, 8, 8, 8, 8, 8],
+[1, 1, 1, 4, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 4, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 5, 5, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[6, 6, 6, 6, 6, 6],
+[6, -1, 0, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 4, -1, 6, 4, 6],
+[6, 4, 0, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 1, 1, 6, 1, 6],
+[5, 5, 5, 5, 5, 5],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 4, 2, 2, 4, 2],
+[0, 4, 0, 4, 4, 0],
+[2, 0, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 1, 1, 6, -1, 6],
+[6, 1, 1, 6, 0, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[4, 1, 1, 4, 4, 1],
+[0, 1, 1, 0, 0, 1],
+[4, 0, 0, 4, 4, 4],
+[2, 2, 2, 2, 2, 2],
+[6, 1, 1, 6, 4, 6],
+[6, 1, 1, 6, 4, 6],
+[6, 0, 0, 6, 0, 6],
+[6, 2, 2, 6, 2, 6],
+[5, 1, 1, 5, 5, 1],
+[0, 1, 1, 0, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 4, 1],
+[0, 4, 0, 4, 4, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 5, 0, 5, 0, 5],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 5, 5, 5, 5],
+[0, 5, 0, 5, 0, 5],
+[-1, 5, 0, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 5, -1, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[4, 5, 0, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[6, 6, 6, 6, 6, 6],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 5, 2, 5, -1, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 0, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 5, 2, 5, 4, 5],
+[0, 5, 0, 5, 0, 5],
+[2, 5, 2, 5, 4, 5],
+[1, 5, 1, 5, 1, 5],
+[2, 4, 2, 4, 4, 2],
+[0, 4, 0, 4, 4, 4],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 6, 2, 6, 6, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 0, 2, 0, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[4, 1, 1, 1, 4, 1],
+[0, 1, 1, 1, 0, 1],
+[4, 0, 0, 4, 4, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 4, 1],
+[0, 0, 0, 0, 0, 0],
+[4, 0, 0, 4, 4, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[6, 0, 0, 6, 0, 6],
+[0, 0, 0, 0, 0, 0],
+[6, 6, 6, 6, 6, 6],
+[5, 5, 5, 5, 5, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 0, 5, 0, 5],
+[5, 5, 1, 5, 1, 5],
+[4, 4, 4, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 0, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[4, 4, 4, 4, 4, 4],
+[4, 4, 0, 4, 4, 4],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[8, 8, 8, 8, 8, 8],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 0, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 1, 1, 4, 4, 1],
+[2, 2, 2, 2, 2, 2],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 0, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 4, 2, 4, 4, 2],
+[1, 1, 1, 1, 1, 1],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[2, 2, 2, 2, 2, 2],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[5, 5, 5, 5, 5, 5],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[4, 4, 4, 4, 4, 4],
+[1, 1, 1, 1, 1, 1],
+[0, 0, 0, 0, 0, 0],
+[0, 0, 0, 0, 0, 0],
+[12, 12, 12, 12, 12, 12]
+]
diff --git a/util/utils.py b/util/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4cb35ef3712f11987c5a4432aee95395d93e6b7
--- /dev/null
+++ b/util/utils.py
@@ -0,0 +1,194 @@
+import numpy as np
+import torch
+import random
+
+
+# Perspective projection in the style of gluPerspective / glm::perspective, but
+# parameterized by the horizontal field of view fovx (in radians) instead of fovy,
+# and with the y axis flipped (note the negative [1][1] entry).
+def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
+    x = np.tan(fovx / 2)
+    return torch.tensor([[1/x,         0,            0,              0],
+                         [  0, -aspect/x,            0,              0],
+                         [  0,         0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+                         [  0,         0,           -1,              0]], dtype=torch.float32, device=device)
+
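+# Illustrative use of the transform helpers in this file (parameter values are arbitrary):
+#   mvp = perspective(fovx=0.7854, aspect=1.0) @ translate(0.0, 0.0, -2.5) @ rotate_y(0.3)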
+
+def translate(x, y, z, device=None):
+    return torch.tensor([[1, 0, 0, x],
+                         [0, 1, 0, y],
+                         [0, 0, 1, z],
+                         [0, 0, 0, 1]], dtype=torch.float32, device=device)
+
+
+def rotate_x(a, device=None):
+    s, c = np.sin(a), np.cos(a)
+    return torch.tensor([[1, 0,  0, 0],
+                         [0, c, -s, 0],
+                         [0, s,  c, 0],
+                         [0, 0,  0, 1]], dtype=torch.float32, device=device)
+
+
+def rotate_y(a, device=None):
+    s, c = np.sin(a), np.cos(a)
+    return torch.tensor([[ c, 0, s, 0],
+                         [ 0, 1, 0, 0],
+                         [-s, 0, c, 0],
+                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)
+
+
+def rotate_z(a, device=None):
+    s, c = np.sin(a), np.cos(a)
+    return torch.tensor([[c, -s, 0, 0],
+                         [s,  c, 0, 0],
+                         [0,  0, 1, 0],
+                         [0,  0, 0, 1]], dtype=torch.float32, device=device)
+
+@torch.no_grad()
+def batch_random_rotation_translation(b, t, device=None):
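+    # Draw b random Gaussian 3x3 matrices, make their rows mutually orthogonal via cross
+    # products and normalize them to obtain random orthonormal bases, then pad each to a
+    # 4x4 homogeneous transform with a uniform random translation in [-t, t] per axis.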
+    m = np.random.normal(size=[b, 3, 3])
+    m[:, 1] = np.cross(m[:, 0], m[:, 2])
+    m[:, 2] = np.cross(m[:, 0], m[:, 1])
+    m = m / np.linalg.norm(m, axis=2, keepdims=True)
+    m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant')
+    m[:, 3, 3] = 1.0
+    m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3])
+    return torch.tensor(m, dtype=torch.float32, device=device)
+
+@torch.no_grad()
+def random_rotation_translation(t, device=None):
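+    # Unbatched variant of batch_random_rotation_translation above.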
+    m = np.random.normal(size=[3, 3])
+    m[1] = np.cross(m[0], m[2])
+    m[2] = np.cross(m[0], m[1])
+    m = m / np.linalg.norm(m, axis=1, keepdims=True)
+    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
+    m[3, 3] = 1.0
+    m[:3, 3] = np.random.uniform(-t, t, size=[3])
+    return torch.tensor(m, dtype=torch.float32, device=device)
+
+
+@torch.no_grad()
+def random_rotation(device=None):
+    m = np.random.normal(size=[3, 3])
+    m[1] = np.cross(m[0], m[2])
+    m[2] = np.cross(m[0], m[1])
+    m = m / np.linalg.norm(m, axis=1, keepdims=True)
+    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
+    m[3, 3] = 1.0
+    m[:3, 3] = np.array([0,0,0]).astype(np.float32)
+    return torch.tensor(m, dtype=torch.float32, device=device)
+
+
+def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    return torch.sum(x*y, -1, keepdim=True)
+
+
+def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
+
+
+def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
+    return x / length(x, eps)
+
+
+def lr_schedule(iter, warmup_iter, scheduler_decay):
+    # Linear warm-up for the first `warmup_iter` steps, then exponential (base-10) decay.
+    if iter < warmup_iter:
+        return iter / warmup_iter
+    return max(0.0, 10 ** (-(iter - warmup_iter) * scheduler_decay))
+
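+# Illustrative use of lr_schedule with PyTorch's LambdaLR (hyper-parameter values are arbitrary):
+#   scheduler = torch.optim.lr_scheduler.LambdaLR(
+#       optimizer, lr_lambda=lambda it: lr_schedule(it, warmup_iter=100, scheduler_decay=1e-4))
+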
+
+def trans_depth(depth):
+    depth = depth[0].detach().cpu().numpy()
+    valid = depth > 0
+    depth[valid] -= depth[valid].min()
+    depth[valid] = ((depth[valid] / depth[valid].max()) * 255)
+    return depth.astype('uint8')
+
+
+def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
+    assert isinstance(input, torch.Tensor)
+    if posinf is None:
+        posinf = torch.finfo(input.dtype).max
+    if neginf is None:
+        neginf = torch.finfo(input.dtype).min
+    assert nan == 0
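+    # nansum over an added singleton dim replaces NaNs with 0 (hence the assert above), and the
+    # clamp maps +/-inf into the dtype's finite range, mimicking torch.nan_to_num.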
+    return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+
+def load_item(filepath):
+    with open(filepath, 'r') as f:
+        items = [name.strip() for name in f.readlines()]
+    return set(items)
+
+def load_prompt(filepath):
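+    # Expects one "<uuid>,<prompt>" entry per line; everything after the first comma is kept
+    # as the prompt, so prompts may themselves contain commas.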
+    uuid2prompt = {}
+    with open(filepath, 'r') as f:
+        for line in f.readlines():
+            list_line = line.split(',')
+            uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip()
+    return uuid2prompt
+
+def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0):
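+    # Shrink each image by `scale` and paste it onto a constant background of value `c`.
+    # When `shift` > 0 the paste position is jittered per image by up to `shift` pixels
+    # (views 0, 2 and 4 stay centered when rgb=True); `aug_shift` adds a small per-channel
+    # brightness jitter of up to aug_shift/255.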
+    if scale == 1:
+        return image_tensor
+    B, C, H, W = image_tensor.shape
+    new_H, new_W = int(H * scale), int(W * scale)
+    resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False)
+    background = torch.zeros_like(image_tensor) + c
+    start_y, start_x = (H - new_H) // 2, (W - new_W) // 2
+    if shift == 0:
+        background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image
+    else:
+        for i in range(B):
+            randx = random.randint(-shift, shift)
+            randy = random.randint(-shift, shift)
+            if rgb:
+                # keep views 0, 2 and 4 un-jittered when rgb=True
+                if i == 0 or i == 2 or i == 4:
+                    randx = 0
+                    randy = 0
+            background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i]
+    if aug_shift == 0:
+        return background  
+    for i in range(B):
+        for j in range(C):
+            background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255
+    return background 
+                               
+def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0):
+    # triview_color: [6,C,H,W]
+    # rgb is useful when shift is not 0
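+    # The six processed views are tiled into a 2x3 grid, i.e. [C, 2*H, 3*W] when dim=1
+    # (assuming square views).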
+    triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift)
+    if blender is False:
+        triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2])
+        triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1)
+        triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2)
+        triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2)
+        triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1)
+        triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2)
+    else:
+        triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2])
+        triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1)
+        triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2])
+        triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2])
+        triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(2)
+        triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2)
+        if fix:
+            # zero out two of the three channels in each processed view, keeping the same
+            # single channel for the pairs (0, 3), (1, 4) and (2, 5)
+            triview_color0[1] = triview_color0[1] * 0
+            triview_color0[2] = triview_color0[2] * 0
+            triview_color3[1] = triview_color3[1] * 0
+            triview_color3[2] = triview_color3[2] * 0
+
+            triview_color1[0] = triview_color1[0] * 0
+            triview_color1[1] = triview_color1[1] * 0
+            triview_color4[0] = triview_color4[0] * 0
+            triview_color4[1] = triview_color4[1] * 0
+
+            triview_color2[0] = triview_color2[0] * 0
+            triview_color2[2] = triview_color2[2] * 0
+            triview_color5[0] = triview_color5[0] * 0
+            triview_color5[2] = triview_color5[2] * 0
+    color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2)
+    color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2)
+    color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim)
+    return color_tensor_gt
+