xunsong.li committed
Commit
7ccc423
0 Parent(s):

init commit

Files changed (46)
  1. .gitignore +4 -0
  2. LICENSE +203 -0
  3. NOTICE +20 -0
  4. README.md +142 -0
  5. app.py +263 -0
  6. assets/mini_program_maliang.png +0 -0
  7. configs/inference/inference_v1.yaml +23 -0
  8. configs/inference/inference_v2.yaml +35 -0
  9. configs/inference/pose_videos/anyone-video-1_kps.mp4 +0 -0
  10. configs/inference/pose_videos/anyone-video-2_kps.mp4 +0 -0
  11. configs/inference/pose_videos/anyone-video-4_kps.mp4 +0 -0
  12. configs/inference/pose_videos/anyone-video-5_kps.mp4 +0 -0
  13. configs/inference/ref_images/anyone-1.png +0 -0
  14. configs/inference/ref_images/anyone-10.png +0 -0
  15. configs/inference/ref_images/anyone-11.png +0 -0
  16. configs/inference/ref_images/anyone-2.png +0 -0
  17. configs/inference/ref_images/anyone-3.png +0 -0
  18. configs/inference/ref_images/anyone-5.png +0 -0
  19. configs/prompts/animation.yaml +26 -0
  20. configs/prompts/test_cases.py +33 -0
  21. requirements.txt +28 -0
  22. scripts/pose2vid.py +167 -0
  23. src/__init__.py +0 -0
  24. src/dwpose/__init__.py +123 -0
  25. src/dwpose/onnxdet.py +130 -0
  26. src/dwpose/onnxpose.py +370 -0
  27. src/dwpose/util.py +378 -0
  28. src/dwpose/wholebody.py +48 -0
  29. src/models/attention.py +443 -0
  30. src/models/motion_module.py +388 -0
  31. src/models/mutual_self_attention.py +363 -0
  32. src/models/pose_guider.py +57 -0
  33. src/models/resnet.py +252 -0
  34. src/models/transformer_2d.py +396 -0
  35. src/models/transformer_3d.py +169 -0
  36. src/models/unet_2d_blocks.py +1074 -0
  37. src/models/unet_2d_condition.py +1308 -0
  38. src/models/unet_3d.py +668 -0
  39. src/models/unet_3d_blocks.py +862 -0
  40. src/pipelines/__init__.py +0 -0
  41. src/pipelines/context.py +76 -0
  42. src/pipelines/pipeline_pose2vid.py +454 -0
  43. src/pipelines/pipeline_pose2vid_long.py +571 -0
  44. src/pipelines/utils.py +29 -0
  45. src/utils/util.py +111 -0
  46. tools/vid2pose.py +38 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ pretrained_weights/
3
+ output/
4
+ .venv/
LICENSE ADDED
@@ -0,0 +1,203 @@
1
+ Copyright @2023-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
NOTICE ADDED
@@ -0,0 +1,20 @@
1
+ ==============================================================
2
+ This repo also contains various third-party components and some code modified from other repos under other open source licenses. The following sections contain licensing information for such third-party libraries.
3
+
4
+ -----------------------------
5
+ magic-animate
6
+ BSD 3-Clause License
7
+ Copyright (c) Bytedance Inc.
8
+
9
+ -----------------------------
10
+ animatediff
11
+ Apache License, Version 2.0
12
+
13
+ -----------------------------
14
+ Dwpose
15
+ Apache License, Version 2.0
16
+
17
+ -----------------------------
18
+ inference pipeline for animatediff-cli-prompt-travel
19
+ animatediff-cli-prompt-travel
20
+ Apache License, Version 2.0
README.md ADDED
@@ -0,0 +1,142 @@
1
+ # 🤗 Introduction
2
+
3
+ This repository reproduces [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). To align with the results demonstrated in the original paper, we adopt various approaches and tricks, which may differ somewhat from the paper and from another [implementation](https://github.com/guoqincode/Open-AnimateAnyone).
4
+
5
+ It's worth noting that this is a very preliminary version, aiming to approximate the performance (roughly 80% in our tests) shown in [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone).
6
+
7
+ We will continue to develop it, and we welcome feedback and ideas from the community. The enhanced version will also be launched on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, which runs on our own full-featured GPU S4000 cloud computing platform.
8
+
9
+ # 📝 Release Plans
10
+
11
+ - [x] Inference code and pretrained weights
12
+ - [ ] Training scripts
13
+
14
+ **Note:** The training code involves private data and packages. We will organize this portion of the code and release it as soon as possible.
15
+
16
+ # 🎞️ Examples
17
+
18
+ Here are some results we generated, at a resolution of 512x768.
19
+
20
+ https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/f0454f30-6726-4ad4-80a7-5b7a15619057
21
+
22
+ https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/337ff231-68a3-4760-a9f9-5113654acf48
23
+
24
+ <table class="center">
25
+
26
+ <tr>
27
+ <td width=50% style="border: none">
28
+ <video controls autoplay loop src="https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/9c4d852e-0a99-4607-8d63-569a1f67a8d2" muted="false"></video>
29
+ </td>
30
+ <td width=50% style="border: none">
31
+ <video controls autoplay loop src="https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/722c6535-2901-4e23-9de9-501b22306ebd" muted="false"></video>
32
+ </td>
33
+ </tr>
34
+
35
+ <tr>
36
+ <td width=50% style="border: none">
37
+ <video controls autoplay loop src="https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/17b907cc-c97e-43cd-af18-b646393c8e8a" muted="false"></video>
38
+ </td>
39
+ <td width=50% style="border: none">
40
+ <video controls autoplay loop src="https://github.com/MooreThreads/Moore-AnimateAnyone/assets/138439222/86f2f6d2-df60-4333-b19b-4c5abcd5999d" muted="false"></video>
41
+ </td>
42
+ </tr>
43
+ </table>
44
+
45
+ **Limitations**: We observe the following shortcomings in the current version:
46
+ 1. The background may exhibit some artifacts when the reference image has a clean background.
47
+ 2. Suboptimal results may arise when there is a scale mismatch between the reference image and keypoints. We have yet to implement preprocessing techniques as mentioned in the [paper](https://arxiv.org/pdf/2311.17117.pdf).
48
+ 3. Some flickering and jittering may occur when the motion sequence is subtle or the scene is static.
49
+
50
+ These issues will be addressed and improved in the near future. We appreciate your patience!
51
+
52
+ # ⚒️ Installation
53
+
54
+ ## Build Environment
55
+
56
+ We recommend Python `>=3.10` and CUDA `11.7`. Then build the environment as follows:
57
+
58
+ ```shell
59
+ # [Optional] Create a virtual env
60
+ python -m venv .venv
61
+ source .venv/bin/activate
62
+ # Install with pip:
63
+ pip install -r requirements.txt
64
+ ```
65
+
66
+ ## Download weights
67
+
68
+ Download our trained [weights](https://huggingface.co/patrolli/AnimateAnyone/tree/main), which include four parts: `denoising_unet.pth`, `reference_unet.pth`, `pose_guider.pth` and `motion_module.pth`.
69
+
70
+ Download the pretrained weights of the base model and other components:
71
+ - [StableDiffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
72
+ - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
73
+ - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
74
+
75
+ Download dwpose weights (`dw-ll_ucoco_384.onnx`, `yolox_l.onnx`) following [this](https://github.com/IDEA-Research/DWPose?tab=readme-ov-file#-dwpose-for-controlnet).
76
+
77
+ Put these weights under a directory, like `./pretrained_weights`, and organize them as follows:
78
+
79
+ ```text
80
+ ./pretrained_weights/
81
+ |-- DWPose
82
+ | |-- dw-ll_ucoco_384.onnx
83
+ | `-- yolox_l.onnx
84
+ |-- image_encoder
85
+ | |-- config.json
86
+ | `-- pytorch_model.bin
87
+ |-- denoising_unet.pth
88
+ |-- motion_module.pth
89
+ |-- pose_guider.pth
90
+ |-- reference_unet.pth
91
+ |-- sd-vae-ft-mse
92
+ | |-- config.json
93
+ | |-- diffusion_pytorch_model.bin
94
+ | `-- diffusion_pytorch_model.safetensors
95
+ `-- stable-diffusion-v1-5
96
+ |-- feature_extractor
97
+ | `-- preprocessor_config.json
98
+ |-- model_index.json
99
+ |-- unet
100
+ | |-- config.json
101
+ | `-- diffusion_pytorch_model.bin
102
+ `-- v1-inference.yaml
103
+ ```
104
+
105
+ Note: If you have already downloaded some of the pretrained models, such as `StableDiffusion V1.5`, you can specify their paths in the config file (e.g. `./configs/prompts/animation.yaml`).
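+
+ If you prefer to script the download of our trained weights, a minimal sketch using `huggingface_hub` (not listed in `requirements.txt`, so install it separately) could look like the following; the base models (Stable Diffusion, VAE, image encoder) still need to be fetched from their own repos:
+
+ ```python
+ # Sketch only: assumes `pip install huggingface_hub`.
+ from huggingface_hub import snapshot_download
+
+ # Pull denoising_unet.pth, reference_unet.pth, pose_guider.pth and
+ # motion_module.pth from the weights repo linked above.
+ snapshot_download(
+     repo_id="patrolli/AnimateAnyone",
+     local_dir="./pretrained_weights",
+ )
+ ```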
106
+
107
+ # 🚀 Inference
108
+
109
+ Here is the CLI command for running the inference script:
110
+
111
+ ```shell
112
+ python -m scripts.pose2vid --config ./configs/prompts/animation.yaml -W 512 -H 784 -L 64
113
+ ```
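+
+ Here `-W`, `-H` and `-L` set the output width, height and number of frames; `--seed`, `--cfg`, `--steps` and `--fps` are also available (see `parse_args` in `scripts/pose2vid.py` for their defaults).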
114
+
115
+ You can refer to the format of `animation.yaml` to add your own reference images or pose videos. To convert a raw video into a pose video (keypoint sequence), run the following command:
116
+
117
+ ```shell
118
+ python tools/vid2pose.py --video_path /path/to/your/video.mp4
119
+ ```
120
+
121
+ # 🎨 Gradio Demo
122
+
123
+ You can run a local Gradio app with the following command:
124
+
125
+ `python app.py`
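+
+ As implemented in `app.py`, the demo loads the weights specified in `./configs/prompts/animation.yaml` the first time you click "Animate", saves results to `./output/gradio`, and launches with `share=True`.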
126
+
127
+ # 🖌️ Try on MoBi MaLiang
128
+
129
+ We will launch this model on our [MoBi MaLiang](https://maliang.mthreads.com/) AIGC platform, which runs on our own full-featured GPU S4000 cloud computing platform. MoBi MaLiang has already integrated various AIGC applications and functionalities (e.g. text-to-image, controllable generation). You can experience it by [clicking this link](https://maliang.mthreads.com/) or by scanning the QR code below via WeChat!
130
+
131
+ <p align="left">
132
+ <img src="assets/mini_program_maliang.png" width="100"/>
134
+ </p>
135
+
136
+ # ⚖️ Disclaimer
137
+
138
+ This project is intended for academic research, and we explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using the generative model. The project contributors have no legal affiliation with, nor accountability for, users' behaviors. It is imperative to use the generative model responsibly, adhering to both ethical and legal standards.
139
+
140
+ # 🙏🏻 Acknowledgements
141
+
142
+ We first thank the authors of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone). Additionally, we would like to thank the contributors to the [magic-animate](https://github.com/magic-research/magic-animate), [animatediff](https://github.com/guoyww/AnimateDiff) and [Open-AnimateAnyone](https://github.com/guoqincode/Open-AnimateAnyone) repositories for their open research and exploration. Furthermore, our repo incorporates some code from [dwpose](https://github.com/IDEA-Research/DWPose) and [animatediff-cli-prompt-travel](https://github.com/s9roll7/animatediff-cli-prompt-travel/), and we extend our thanks to them as well.
app.py ADDED
@@ -0,0 +1,263 @@
1
+ import os
2
+ import random
3
+ from datetime import datetime
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import AutoencoderKL, DDIMScheduler
9
+ from einops import repeat
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from transformers import CLIPVisionModelWithProjection
14
+
15
+ from src.models.pose_guider import PoseGuider
16
+ from src.models.unet_2d_condition import UNet2DConditionModel
17
+ from src.models.unet_3d import UNet3DConditionModel
18
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
19
+ from src.utils.util import get_fps, read_frames, save_videos_grid
20
+
21
+
22
+ class AnimateController:
23
+ def __init__(
24
+ self,
25
+ config_path="./configs/prompts/animation.yaml",
26
+ weight_dtype=torch.float16,
27
+ ):
28
+ # Read pretrained weights path from config
29
+ self.config = OmegaConf.load(config_path)
30
+ self.pipeline = None
31
+ self.weight_dtype = weight_dtype
32
+
33
+ def animate(
34
+ self,
35
+ ref_image,
36
+ pose_video_path,
37
+ width=512,
38
+ height=768,
39
+ length=24,
40
+ num_inference_steps=25,
41
+ cfg=3.5,
42
+ seed=123,
43
+ ):
44
+ generator = torch.manual_seed(seed)
45
+ if isinstance(ref_image, np.ndarray):
46
+ ref_image = Image.fromarray(ref_image)
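+ # Build the pipeline lazily on the first request and reuse it afterwards.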
47
+ if self.pipeline is None:
48
+ vae = AutoencoderKL.from_pretrained(
49
+ self.config.pretrained_vae_path,
50
+ ).to("cuda", dtype=self.weight_dtype)
51
+
52
+ reference_unet = UNet2DConditionModel.from_pretrained(
53
+ self.config.pretrained_base_model_path,
54
+ subfolder="unet",
55
+ ).to(dtype=self.weight_dtype, device="cuda")
56
+
57
+ inference_config_path = self.config.inference_config
58
+ infer_config = OmegaConf.load(inference_config_path)
59
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
60
+ self.config.pretrained_base_model_path,
61
+ self.config.motion_module_path,
62
+ subfolder="unet",
63
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
64
+ ).to(dtype=self.weight_dtype, device="cuda")
65
+
66
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
67
+ dtype=self.weight_dtype, device="cuda"
68
+ )
69
+
70
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
71
+ self.config.image_encoder_path
72
+ ).to(dtype=self.weight_dtype, device="cuda")
73
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
74
+ scheduler = DDIMScheduler(**sched_kwargs)
75
+
76
+ # load pretrained weights
77
+ denoising_unet.load_state_dict(
78
+ torch.load(self.config.denoising_unet_path, map_location="cpu"),
79
+ strict=False,
80
+ )
81
+ reference_unet.load_state_dict(
82
+ torch.load(self.config.reference_unet_path, map_location="cpu"),
83
+ )
84
+ pose_guider.load_state_dict(
85
+ torch.load(self.config.pose_guider_path, map_location="cpu"),
86
+ )
87
+
88
+ pipe = Pose2VideoPipeline(
89
+ vae=vae,
90
+ image_encoder=image_enc,
91
+ reference_unet=reference_unet,
92
+ denoising_unet=denoising_unet,
93
+ pose_guider=pose_guider,
94
+ scheduler=scheduler,
95
+ )
96
+ pipe = pipe.to("cuda", dtype=self.weight_dtype)
97
+ self.pipeline = pipe
98
+
99
+ pose_images = read_frames(pose_video_path)
100
+ src_fps = get_fps(pose_video_path)
101
+
102
+ pose_list = []
103
+ pose_tensor_list = []
104
+ pose_transform = transforms.Compose(
105
+ [transforms.Resize((height, width)), transforms.ToTensor()]
106
+ )
107
+ for pose_image_pil in pose_images[:length]:
108
+ pose_list.append(pose_image_pil)
109
+ pose_tensor_list.append(pose_transform(pose_image_pil))
110
+
111
+ video = self.pipeline(
112
+ ref_image,
113
+ pose_list,
114
+ width=width,
115
+ height=height,
116
+ video_length=length,
117
+ num_inference_steps=num_inference_steps,
118
+ guidance_scale=cfg,
119
+ generator=generator,
120
+ ).videos
121
+
122
+ ref_image_tensor = pose_transform(ref_image) # (c, h, w)
123
+ ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
124
+ ref_image_tensor = repeat(
125
+ ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
126
+ )
127
+ pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
128
+ pose_tensor = pose_tensor.transpose(0, 1)
129
+ pose_tensor = pose_tensor.unsqueeze(0)
130
+ video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
131
+
132
+ save_dir = f"./output/gradio"
133
+ if not os.path.exists(save_dir):
134
+ os.makedirs(save_dir, exist_ok=True)
135
+ date_str = datetime.now().strftime("%Y%m%d")
136
+ time_str = datetime.now().strftime("%H%M")
137
+ out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
138
+ save_videos_grid(
139
+ video,
140
+ out_path,
141
+ n_rows=3,
142
+ fps=src_fps,
143
+ )
144
+
145
+ torch.cuda.empty_cache()
146
+
147
+ return out_path
148
+
149
+
150
+ controller = AnimateController()
151
+
152
+
153
+ def ui():
154
+ with gr.Blocks() as demo:
155
+ gr.Markdown(
156
+ """
157
+ # Moore-AnimateAnyone Demo
158
+ """
159
+ )
160
+ animation = gr.Video(
161
+ format="mp4",
162
+ label="Animation Results",
163
+ height=448,
164
+ autoplay=True,
165
+ )
166
+
167
+ with gr.Row():
168
+ reference_image = gr.Image(label="Reference Image")
169
+ motion_sequence = gr.Video(
170
+ format="mp4", label="Motion Sequence", height=512
171
+ )
172
+
173
+ with gr.Column():
174
+ width_slider = gr.Slider(
175
+ label="Width", minimum=448, maximum=768, value=512, step=64
176
+ )
177
+ height_slider = gr.Slider(
178
+ label="Height", minimum=512, maximum=1024, value=768, step=64
179
+ )
180
+ length_slider = gr.Slider(
181
+ label="Video Length", minimum=24, maximum=128, value=24, step=24
182
+ )
183
+ with gr.Row():
184
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
185
+ seed_button = gr.Button(
186
+ value="\U0001F3B2", elem_classes="toolbutton"
187
+ )
188
+ seed_button.click(
189
+ fn=lambda: gr.Textbox.update(value=random.randint(1, 10**8)),
190
+ inputs=[],
191
+ outputs=[seed_textbox],
192
+ )
193
+ with gr.Row():
194
+ sampling_steps = gr.Slider(
195
+ label="Sampling steps",
196
+ value=25,
197
+ info="default: 25",
198
+ step=5,
199
+ maximum=30,
200
+ minimum=10,
201
+ )
202
+ guidance_scale = gr.Slider(
203
+ label="Guidance scale",
204
+ value=3.5,
205
+ info="default: 3.5",
206
+ step=0.5,
207
+ maximum=10,
208
+ minimum=2.0,
209
+ )
210
+ submit = gr.Button("Animate")
211
+
212
+ def read_video(video):
213
+ return video
214
+
215
+ def read_image(image):
216
+ return Image.fromarray(image)
217
+
218
+ # when user uploads a new video
219
+ motion_sequence.upload(read_video, motion_sequence, motion_sequence)
220
+ # when `first_frame` is updated
221
+ reference_image.upload(read_image, reference_image, reference_image)
222
+ # when the `submit` button is clicked
223
+ submit.click(
224
+ controller.animate,
225
+ [
226
+ reference_image,
227
+ motion_sequence,
228
+ width_slider,
229
+ height_slider,
230
+ length_slider,
231
+ sampling_steps,
232
+ guidance_scale,
233
+ seed_textbox,
234
+ ],
235
+ animation,
236
+ )
237
+
238
+ # Examples
239
+ gr.Markdown("## Examples")
240
+ gr.Examples(
241
+ examples=[
242
+ [
243
+ "./configs/inference/ref_images/anyone-5.png",
244
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
245
+ ],
246
+ [
247
+ "./configs/inference/ref_images/anyone-10.png",
248
+ "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
249
+ ],
250
+ [
251
+ "./configs/inference/ref_images/anyone-2.png",
252
+ "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
253
+ ],
254
+ ],
255
+ inputs=[reference_image, motion_sequence],
256
+ outputs=animation,
257
+ )
258
+
259
+ return demo
260
+
261
+
262
+ demo = ui()
263
+ demo.launch(share=True)
assets/mini_program_maliang.png ADDED
configs/inference/inference_v1.yaml ADDED
@@ -0,0 +1,23 @@
1
+ unet_additional_kwargs:
2
+ unet_use_cross_frame_attention: false
3
+ unet_use_temporal_attention: false
4
+ use_motion_module: true
5
+ motion_module_resolutions: [1,2,4,8]
6
+ motion_module_mid_block: false
7
+ motion_module_decoder_only: false
8
+ motion_module_type: "Vanilla"
9
+
10
+ motion_module_kwargs:
11
+ num_attention_heads: 8
12
+ num_transformer_block: 1
13
+ attention_block_types: [ "Temporal_Self", "Temporal_Self" ]
14
+ temporal_position_encoding: true
15
+ temporal_position_encoding_max_len: 24
16
+ temporal_attention_dim_div: 1
17
+
18
+ noise_scheduler_kwargs:
19
+ beta_start: 0.00085
20
+ beta_end: 0.012
21
+ beta_schedule: "linear"
22
+ steps_offset: 1
23
+ clip_sample: False
configs/inference/inference_v2.yaml ADDED
@@ -0,0 +1,35 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions:
7
+ - 1
8
+ - 2
9
+ - 4
10
+ - 8
11
+ motion_module_mid_block: true
12
+ motion_module_decoder_only: false
13
+ motion_module_type: Vanilla
14
+ motion_module_kwargs:
15
+ num_attention_heads: 8
16
+ num_transformer_block: 1
17
+ attention_block_types:
18
+ - Temporal_Self
19
+ - Temporal_Self
20
+ temporal_position_encoding: true
21
+ temporal_position_encoding_max_len: 32
22
+ temporal_attention_dim_div: 1
23
+
24
+ noise_scheduler_kwargs:
25
+ beta_start: 0.00085
26
+ beta_end: 0.012
27
+ beta_schedule: "linear"
28
+ clip_sample: false
29
+ steps_offset: 1
30
+ ### Zero-SNR params
31
+ prediction_type: "v_prediction"
32
+ rescale_betas_zero_snr: True
33
+ timestep_spacing: "trailing"
34
+
35
+ sampler: DDIM
configs/inference/pose_videos/anyone-video-1_kps.mp4 ADDED
Binary file (755 kB).
 
configs/inference/pose_videos/anyone-video-2_kps.mp4 ADDED
Binary file (520 kB).
 
configs/inference/pose_videos/anyone-video-4_kps.mp4 ADDED
Binary file (974 kB).
 
configs/inference/pose_videos/anyone-video-5_kps.mp4 ADDED
Binary file (674 kB).
 
configs/inference/ref_images/anyone-1.png ADDED
configs/inference/ref_images/anyone-10.png ADDED
configs/inference/ref_images/anyone-11.png ADDED
configs/inference/ref_images/anyone-2.png ADDED
configs/inference/ref_images/anyone-3.png ADDED
configs/inference/ref_images/anyone-5.png ADDED
configs/prompts/animation.yaml ADDED
@@ -0,0 +1,26 @@
1
+ pretrained_base_model_path: "./pretrained_weights/stable-diffusion-v1-5/"
2
+ pretrained_vae_path: "./pretrained_weights/sd-vae-ft-mse"
3
+ image_encoder_path: "./pretrained_weights/image_encoder"
4
+ denoising_unet_path: "./pretrained_weights/denoising_unet.pth"
5
+ reference_unet_path: "./pretrained_weights/reference_unet.pth"
6
+ pose_guider_path: "./pretrained_weights/pose_guider.pth"
7
+ motion_module_path: "./pretrained_weights/motion_module.pth"
8
+
9
+ inference_config: "./configs/inference/inference_v2.yaml"
10
+ weight_dtype: 'fp16'
11
+
12
+ test_cases:
13
+ "./configs/inference/ref_images/anyone-2.png":
14
+ - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
15
+ - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
16
+ "./configs/inference/ref_images/anyone-10.png":
17
+ - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
18
+ - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
19
+ "./configs/inference/ref_images/anyone-11.png":
20
+ - "./configs/inference/pose_videos/anyone-video-1_kps.mp4"
21
+ - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
22
+ "./configs/inference/ref_images/anyone-3.png":
23
+ - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
24
+ - "./configs/inference/pose_videos/anyone-video-5_kps.mp4"
25
+ "./configs/inference/ref_images/anyone-5.png":
26
+ - "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
configs/prompts/test_cases.py ADDED
@@ -0,0 +1,33 @@
1
+ TestCasesDict = {
2
+ 0: [
3
+ {
4
+ "./configs/inference/ref_images/anyone-2.png": [
5
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
6
+ "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
7
+ ]
8
+ },
9
+ {
10
+ "./configs/inference/ref_images/anyone-10.png": [
11
+ "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
12
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
13
+ ]
14
+ },
15
+ {
16
+ "./configs/inference/ref_images/anyone-11.png": [
17
+ "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
18
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
19
+ ]
20
+ },
21
+ {
22
+ "./configs/inference/ref_images/anyone-3.png": [
23
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
24
+ "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
25
+ ]
26
+ },
27
+ {
28
+ "./configs/inference/ref_images/anyone-5.png": [
29
+ "./configs/inference/pose_videos/anyone-video-2_kps.mp4"
30
+ ]
31
+ },
32
+ ],
33
+ }
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ accelerate==0.21.0
2
+ av==11.0.0
3
+ clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
4
+ decord==0.6.0
5
+ diffusers==0.24.0
6
+ einops==0.4.1
7
+ gradio==3.41.2
8
+ gradio_client==0.5.0
9
+ imageio==2.33.0
10
+ imageio-ffmpeg==0.4.9
11
+ numpy==1.23.5
12
+ omegaconf==2.2.3
13
+ onnxruntime==1.16.3
14
+ onnxruntime-gpu==1.16.3
15
+ open-clip-torch==2.20.0
16
+ opencv-contrib-python==4.8.1.78
17
+ opencv-python==4.8.1.78
18
+ Pillow==9.5.0
19
+ scikit-image==0.21.0
20
+ scikit-learn==1.3.2
21
+ scipy==1.11.4
22
+ torch==2.0.1
23
+ torchdiffeq==0.2.3
24
+ torchmetrics==1.2.1
25
+ torchsde==0.2.5
26
+ torchvision==0.15.2
27
+ tqdm==4.66.1
28
+ transformers==4.30.2
scripts/pose2vid.py ADDED
@@ -0,0 +1,167 @@
1
+ import argparse
2
+ import os
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ import av
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
13
+ from einops import repeat
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
+
19
+ from configs.prompts.test_cases import TestCasesDict
20
+ from src.models.pose_guider import PoseGuider
21
+ from src.models.unet_2d_condition import UNet2DConditionModel
22
+ from src.models.unet_3d import UNet3DConditionModel
23
+ from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
24
+ from src.utils.util import get_fps, read_frames, save_videos_grid
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--config")
30
+ parser.add_argument("-W", type=int, default=512)
31
+ parser.add_argument("-H", type=int, default=784)
32
+ parser.add_argument("-L", type=int, default=24)
33
+ parser.add_argument("--seed", type=int, default=42)
34
+ parser.add_argument("--cfg", type=float, default=3.5)
35
+ parser.add_argument("--steps", type=int, default=30)
36
+ parser.add_argument("--fps", type=int)
37
+ args = parser.parse_args()
38
+
39
+ return args
40
+
41
+
42
+ def main():
43
+ args = parse_args()
44
+
45
+ config = OmegaConf.load(args.config)
46
+
47
+ if config.weight_dtype == "fp16":
48
+ weight_dtype = torch.float16
49
+ else:
50
+ weight_dtype = torch.float32
51
+
52
+ vae = AutoencoderKL.from_pretrained(
53
+ config.pretrained_vae_path,
54
+ ).to("cuda", dtype=weight_dtype)
55
+
56
+ reference_unet = UNet2DConditionModel.from_pretrained(
57
+ config.pretrained_base_model_path,
58
+ subfolder="unet",
59
+ ).to(dtype=weight_dtype, device="cuda")
60
+
61
+ inference_config_path = config.inference_config
62
+ infer_config = OmegaConf.load(inference_config_path)
63
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
64
+ config.pretrained_base_model_path,
65
+ config.motion_module_path,
66
+ subfolder="unet",
67
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
68
+ ).to(dtype=weight_dtype, device="cuda")
69
+
70
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
71
+ dtype=weight_dtype, device="cuda"
72
+ )
73
+
74
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
75
+ config.image_encoder_path
76
+ ).to(dtype=weight_dtype, device="cuda")
77
+
78
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
79
+ scheduler = DDIMScheduler(**sched_kwargs)
80
+
81
+ generator = torch.manual_seed(args.seed)
82
+
83
+ width, height = args.W, args.H
84
+
85
+ # load pretrained weights
86
+ denoising_unet.load_state_dict(
87
+ torch.load(config.denoising_unet_path, map_location="cpu"),
88
+ strict=False,
89
+ )
90
+ reference_unet.load_state_dict(
91
+ torch.load(config.reference_unet_path, map_location="cpu"),
92
+ )
93
+ pose_guider.load_state_dict(
94
+ torch.load(config.pose_guider_path, map_location="cpu"),
95
+ )
96
+
97
+ pipe = Pose2VideoPipeline(
98
+ vae=vae,
99
+ image_encoder=image_enc,
100
+ reference_unet=reference_unet,
101
+ denoising_unet=denoising_unet,
102
+ pose_guider=pose_guider,
103
+ scheduler=scheduler,
104
+ )
105
+ pipe = pipe.to("cuda", dtype=weight_dtype)
106
+
107
+ date_str = datetime.now().strftime("%Y%m%d")
108
+ time_str = datetime.now().strftime("%H%M")
109
+ save_dir_name = f"{time_str}--seed_{args.seed}-{args.W}x{args.H}"
110
+
111
+ save_dir = Path(f"output/{date_str}/{save_dir_name}")
112
+ save_dir.mkdir(exist_ok=True, parents=True)
113
+
114
+ for ref_image_path in config["test_cases"].keys():
115
+ # Each ref_image may correspond to multiple actions
116
+ for pose_video_path in config["test_cases"][ref_image_path]:
117
+ ref_name = Path(ref_image_path).stem
118
+ pose_name = Path(pose_video_path).stem.replace("_kps", "")
119
+
120
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
121
+
122
+ pose_list = []
123
+ pose_tensor_list = []
124
+ pose_images = read_frames(pose_video_path)
125
+ src_fps = get_fps(pose_video_path)
126
+ print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
127
+ pose_transform = transforms.Compose(
128
+ [transforms.Resize((height, width)), transforms.ToTensor()]
129
+ )
130
+ for pose_image_pil in pose_images[: args.L]:
131
+ pose_tensor_list.append(pose_transform(pose_image_pil))
132
+ pose_list.append(pose_image_pil)
133
+
134
+ ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
135
+ ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(
136
+ 0
137
+ ) # (1, c, 1, h, w)
138
+ ref_image_tensor = repeat(
139
+ ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=args.L
140
+ )
141
+
142
+ pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
143
+ pose_tensor = pose_tensor.transpose(0, 1)
144
+ pose_tensor = pose_tensor.unsqueeze(0)
145
+
146
+ video = pipe(
147
+ ref_image_pil,
148
+ pose_list,
149
+ width,
150
+ height,
151
+ args.L,
152
+ args.steps,
153
+ args.cfg,
154
+ generator=generator,
155
+ ).videos
156
+
157
+ video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
158
+ save_videos_grid(
159
+ video,
160
+ f"{save_dir}/{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.mp4",
161
+ n_rows=3,
162
+ fps=src_fps if args.fps is None else args.fps,
163
+ )
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main()
src/__init__.py ADDED
File without changes
src/dwpose/__init__.py ADDED
@@ -0,0 +1,123 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ # Openpose
3
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
4
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
5
+ # 3rd Edited by ControlNet
6
+ # 4th Edited by ControlNet (added face and correct hands)
7
+
8
+ import copy
9
+ import os
10
+
11
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ from controlnet_aux.util import HWC3, resize_image
16
+ from PIL import Image
17
+
18
+ from . import util
19
+ from .wholebody import Wholebody
20
+
21
+
22
+ def draw_pose(pose, H, W):
23
+ bodies = pose["bodies"]
24
+ faces = pose["faces"]
25
+ hands = pose["hands"]
26
+ candidate = bodies["candidate"]
27
+ subset = bodies["subset"]
28
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
29
+
30
+ canvas = util.draw_bodypose(canvas, candidate, subset)
31
+
32
+ canvas = util.draw_handpose(canvas, hands)
33
+
34
+ canvas = util.draw_facepose(canvas, faces)
35
+
36
+ return canvas
37
+
38
+
39
+ class DWposeDetector:
40
+ def __init__(self):
41
+ pass
42
+
43
+ def to(self, device):
44
+ self.pose_estimation = Wholebody(device)
45
+ return self
46
+
47
+ def cal_height(self, input_image):
48
+ input_image = cv2.cvtColor(
49
+ np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
50
+ )
51
+
52
+ input_image = HWC3(input_image)
53
+ H, W, C = input_image.shape
54
+ with torch.no_grad():
55
+ candidate, subset = self.pose_estimation(input_image)
56
+ nums, keys, locs = candidate.shape
57
+ # candidate[..., 0] /= float(W)
58
+ # candidate[..., 1] /= float(H)
59
+ body = candidate
60
+ return body[0, ..., 1].min(), body[..., 1].max() - body[..., 1].min()
61
+
62
+ def __call__(
63
+ self,
64
+ input_image,
65
+ detect_resolution=512,
66
+ image_resolution=512,
67
+ output_type="pil",
68
+ **kwargs,
69
+ ):
70
+ input_image = cv2.cvtColor(
71
+ np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR
72
+ )
73
+
74
+ input_image = HWC3(input_image)
75
+ input_image = resize_image(input_image, detect_resolution)
76
+ H, W, C = input_image.shape
77
+ with torch.no_grad():
78
+ candidate, subset = self.pose_estimation(input_image)
79
+ nums, keys, locs = candidate.shape
80
+ candidate[..., 0] /= float(W)
81
+ candidate[..., 1] /= float(H)
82
+ score = subset[:, :18]
83
+ max_ind = np.mean(score, axis=-1).argmax(axis=0)
84
+ score = score[[max_ind]]
85
+ body = candidate[:, :18].copy()
86
+ body = body[[max_ind]]
87
+ nums = 1
88
+ body = body.reshape(nums * 18, locs)
89
+ body_score = copy.deepcopy(score)
90
+ for i in range(len(score)):
91
+ for j in range(len(score[i])):
92
+ if score[i][j] > 0.3:
93
+ score[i][j] = int(18 * i + j)
94
+ else:
95
+ score[i][j] = -1
96
+
97
+ un_visible = subset < 0.3
98
+ candidate[un_visible] = -1
99
+
100
+ foot = candidate[:, 18:24]
101
+
102
+ faces = candidate[[max_ind], 24:92]
103
+
104
+ hands = candidate[[max_ind], 92:113]
105
+ hands = np.vstack([hands, candidate[[max_ind], 113:]])
106
+
107
+ bodies = dict(candidate=body, subset=score)
108
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
109
+
110
+ detected_map = draw_pose(pose, H, W)
111
+ detected_map = HWC3(detected_map)
112
+
113
+ img = resize_image(input_image, image_resolution)
114
+ H, W, C = img.shape
115
+
116
+ detected_map = cv2.resize(
117
+ detected_map, (W, H), interpolation=cv2.INTER_LINEAR
118
+ )
119
+
120
+ if output_type == "pil":
121
+ detected_map = Image.fromarray(detected_map)
122
+
123
+ return detected_map, body_score
src/dwpose/onnxdet.py ADDED
@@ -0,0 +1,130 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ import cv2
3
+ import numpy as np
4
+ import onnxruntime
5
+
6
+
7
+ def nms(boxes, scores, nms_thr):
8
+ """Single class NMS implemented in Numpy."""
9
+ x1 = boxes[:, 0]
10
+ y1 = boxes[:, 1]
11
+ x2 = boxes[:, 2]
12
+ y2 = boxes[:, 3]
13
+
14
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
15
+ order = scores.argsort()[::-1]
16
+
17
+ keep = []
18
+ while order.size > 0:
19
+ i = order[0]
20
+ keep.append(i)
21
+ xx1 = np.maximum(x1[i], x1[order[1:]])
22
+ yy1 = np.maximum(y1[i], y1[order[1:]])
23
+ xx2 = np.minimum(x2[i], x2[order[1:]])
24
+ yy2 = np.minimum(y2[i], y2[order[1:]])
25
+
26
+ w = np.maximum(0.0, xx2 - xx1 + 1)
27
+ h = np.maximum(0.0, yy2 - yy1 + 1)
28
+ inter = w * h
29
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
30
+
31
+ inds = np.where(ovr <= nms_thr)[0]
32
+ order = order[inds + 1]
33
+
34
+ return keep
35
+
36
+
37
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
38
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
39
+ final_dets = []
40
+ num_classes = scores.shape[1]
41
+ for cls_ind in range(num_classes):
42
+ cls_scores = scores[:, cls_ind]
43
+ valid_score_mask = cls_scores > score_thr
44
+ if valid_score_mask.sum() == 0:
45
+ continue
46
+ else:
47
+ valid_scores = cls_scores[valid_score_mask]
48
+ valid_boxes = boxes[valid_score_mask]
49
+ keep = nms(valid_boxes, valid_scores, nms_thr)
50
+ if len(keep) > 0:
51
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
52
+ dets = np.concatenate(
53
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54
+ )
55
+ final_dets.append(dets)
56
+ if len(final_dets) == 0:
57
+ return None
58
+ return np.concatenate(final_dets, 0)
59
+
60
+
61
+ def demo_postprocess(outputs, img_size, p6=False):
62
+ grids = []
63
+ expanded_strides = []
64
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
65
+
66
+ hsizes = [img_size[0] // stride for stride in strides]
67
+ wsizes = [img_size[1] // stride for stride in strides]
68
+
69
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
70
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
71
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
72
+ grids.append(grid)
73
+ shape = grid.shape[:2]
74
+ expanded_strides.append(np.full((*shape, 1), stride))
75
+
76
+ grids = np.concatenate(grids, 1)
77
+ expanded_strides = np.concatenate(expanded_strides, 1)
78
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
79
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
80
+
81
+ return outputs
82
+
83
+
84
+ def preprocess(img, input_size, swap=(2, 0, 1)):
85
+ if len(img.shape) == 3:
86
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
87
+ else:
88
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
89
+
90
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
91
+ resized_img = cv2.resize(
92
+ img,
93
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
94
+ interpolation=cv2.INTER_LINEAR,
95
+ ).astype(np.uint8)
96
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
97
+
98
+ padded_img = padded_img.transpose(swap)
99
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
100
+ return padded_img, r
101
+
102
+
103
+ def inference_detector(session, oriImg):
104
+ input_shape = (640, 640)
105
+ img, ratio = preprocess(oriImg, input_shape)
106
+
107
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
108
+ output = session.run(None, ort_inputs)
109
+ predictions = demo_postprocess(output[0], input_shape)[0]
110
+
111
+ boxes = predictions[:, :4]
112
+ scores = predictions[:, 4:5] * predictions[:, 5:]
113
+
114
+ boxes_xyxy = np.ones_like(boxes)
115
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
116
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
117
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
118
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
119
+ boxes_xyxy /= ratio
120
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
121
+ if dets is not None:
122
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
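+ # Keep detections with score > 0.3 whose class is 0 ('person' in the COCO label set).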
123
+ isscore = final_scores > 0.3
124
+ iscat = final_cls_inds == 0
125
+ isbbox = [i and j for (i, j) in zip(isscore, iscat)]
126
+ final_boxes = final_boxes[isbbox]
127
+ else:
128
+ return []
129
+
130
+ return final_boxes
src/dwpose/onnxpose.py ADDED
@@ -0,0 +1,370 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ from typing import List, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+
9
+ def preprocess(
10
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12
+ """Do preprocessing for RTMPose model inference.
13
+
14
+ Args:
15
+ img (np.ndarray): Input image in shape.
16
+ input_size (tuple): Input image size in shape (w, h).
17
+
18
+ Returns:
19
+ tuple:
20
+ - resized_img (np.ndarray): Preprocessed image.
21
+ - center (np.ndarray): Center of image.
22
+ - scale (np.ndarray): Scale of image.
23
+ """
24
+ # get shape of image
25
+ img_shape = img.shape[:2]
26
+ out_img, out_center, out_scale = [], [], []
27
+ if len(out_bbox) == 0:
28
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29
+ for i in range(len(out_bbox)):
30
+ x0 = out_bbox[i][0]
31
+ y0 = out_bbox[i][1]
32
+ x1 = out_bbox[i][2]
33
+ y1 = out_bbox[i][3]
34
+ bbox = np.array([x0, y0, x1, y1])
35
+
36
+ # get center and scale
37
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38
+
39
+ # do affine transformation
40
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
41
+
42
+ # normalize image
43
+ mean = np.array([123.675, 116.28, 103.53])
44
+ std = np.array([58.395, 57.12, 57.375])
45
+ resized_img = (resized_img - mean) / std
46
+
47
+ out_img.append(resized_img)
48
+ out_center.append(center)
49
+ out_scale.append(scale)
50
+
51
+ return out_img, out_center, out_scale
52
+
53
+
54
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
55
+ """Inference RTMPose model.
56
+
57
+ Args:
58
+ sess (ort.InferenceSession): ONNXRuntime session.
59
+ img (np.ndarray): Input image in shape.
60
+
61
+ Returns:
62
+ outputs (np.ndarray): Output of RTMPose model.
63
+ """
64
+ all_out = []
65
+ # build input
66
+ for i in range(len(img)):
67
+ input = [img[i].transpose(2, 0, 1)]
68
+
69
+ # build output
70
+ sess_input = {sess.get_inputs()[0].name: input}
71
+ sess_output = []
72
+ for out in sess.get_outputs():
73
+ sess_output.append(out.name)
74
+
75
+ # run model
76
+ outputs = sess.run(sess_output, sess_input)
77
+ all_out.append(outputs)
78
+
79
+ return all_out
80
+
81
+
82
+ def postprocess(
83
+ outputs: List[np.ndarray],
84
+ model_input_size: Tuple[int, int],
85
+ center: Tuple[int, int],
86
+ scale: Tuple[int, int],
87
+ simcc_split_ratio: float = 2.0,
88
+ ) -> Tuple[np.ndarray, np.ndarray]:
89
+ """Postprocess for RTMPose model output.
90
+
91
+ Args:
92
+ outputs (np.ndarray): Output of RTMPose model.
93
+ model_input_size (tuple): RTMPose model Input image size.
94
+ center (tuple): Center of bbox in shape (x, y).
95
+ scale (tuple): Scale of bbox in shape (w, h).
96
+ simcc_split_ratio (float): Split ratio of simcc.
97
+
98
+ Returns:
99
+ tuple:
100
+ - keypoints (np.ndarray): Rescaled keypoints.
101
+ - scores (np.ndarray): Model predict scores.
102
+ """
103
+ all_key = []
104
+ all_score = []
105
+ for i in range(len(outputs)):
106
+ # use simcc to decode
107
+ simcc_x, simcc_y = outputs[i]
108
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
109
+
110
+ # rescale keypoints
111
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
112
+ all_key.append(keypoints[0])
113
+ all_score.append(scores[0])
114
+
115
+ return np.array(all_key), np.array(all_score)
116
+
117
+
118
+ def bbox_xyxy2cs(
119
+ bbox: np.ndarray, padding: float = 1.0
120
+ ) -> Tuple[np.ndarray, np.ndarray]:
121
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
122
+
123
+ Args:
124
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
125
+ as (left, top, right, bottom)
126
+ padding (float): BBox padding factor that will be multilied to scale.
127
+ Default: 1.0
128
+
129
+ Returns:
130
+ tuple: A tuple containing center and scale.
131
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
132
+ (n, 2)
133
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
134
+ (n, 2)
135
+ """
136
+ # convert single bbox from (4, ) to (1, 4)
137
+ dim = bbox.ndim
138
+ if dim == 1:
139
+ bbox = bbox[None, :]
140
+
141
+ # get bbox center and scale
142
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
143
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
144
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
145
+
146
+ if dim == 1:
147
+ center = center[0]
148
+ scale = scale[0]
149
+
150
+ return center, scale
151
+
152
+
153
+ def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
154
+ """Extend the scale to match the given aspect ratio.
155
+
156
+ Args:
157
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
158
+ aspect_ratio (float): The ratio of ``w/h``
159
+
160
+ Returns:
161
+ np.ndarray: The reshaped image scale in (2, )
162
+ """
163
+ w, h = np.hsplit(bbox_scale, [1])
164
+ bbox_scale = np.where(
165
+ w > h * aspect_ratio,
166
+ np.hstack([w, w / aspect_ratio]),
167
+ np.hstack([h * aspect_ratio, h]),
168
+ )
169
+ return bbox_scale
170
+
171
+
172
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
173
+ """Rotate a point by an angle.
174
+
175
+ Args:
176
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
177
+ angle_rad (float): rotation angle in radian
178
+
179
+ Returns:
180
+ np.ndarray: Rotated point in shape (2, )
181
+ """
182
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
183
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
184
+ return rot_mat @ pt
185
+
186
+
187
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
188
+ """To calculate the affine matrix, three pairs of points are required. This
189
+ function is used to get the 3rd point, given 2D points a & b.
190
+
191
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
192
+ anticlockwise, using b as the rotation center.
193
+
194
+ Args:
195
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
196
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
197
+
198
+ Returns:
199
+ np.ndarray: The 3rd point.
200
+ """
201
+ direction = a - b
202
+ c = b + np.r_[-direction[1], direction[0]]
203
+ return c
204
+
205
+
206
+ def get_warp_matrix(
207
+ center: np.ndarray,
208
+ scale: np.ndarray,
209
+ rot: float,
210
+ output_size: Tuple[int, int],
211
+ shift: Tuple[float, float] = (0.0, 0.0),
212
+ inv: bool = False,
213
+ ) -> np.ndarray:
214
+ """Calculate the affine transformation matrix that can warp the bbox area
215
+ in the input image to the output size.
216
+
217
+ Args:
218
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
219
+ scale (np.ndarray[2, ]): Scale of the bounding box
220
+ wrt [width, height].
221
+ rot (float): Rotation angle (degree).
222
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
223
+ destination heatmaps.
224
+ shift (0-100%): Shift translation ratio wrt the width/height.
225
+ Default (0., 0.).
226
+ inv (bool): Option to inverse the affine transform direction.
227
+ (inv=False: src->dst or inv=True: dst->src)
228
+
229
+ Returns:
230
+ np.ndarray: A 2x3 transformation matrix
231
+ """
232
+ shift = np.array(shift)
233
+ src_w = scale[0]
234
+ dst_w = output_size[0]
235
+ dst_h = output_size[1]
236
+
237
+ # compute transformation matrix
238
+ rot_rad = np.deg2rad(rot)
239
+ src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
240
+ dst_dir = np.array([0.0, dst_w * -0.5])
241
+
242
+ # get four corners of the src rectangle in the original image
243
+ src = np.zeros((3, 2), dtype=np.float32)
244
+ src[0, :] = center + scale * shift
245
+ src[1, :] = center + src_dir + scale * shift
246
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
247
+
248
+ # get four corners of the dst rectangle in the input image
249
+ dst = np.zeros((3, 2), dtype=np.float32)
250
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
251
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
252
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
253
+
254
+ if inv:
255
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
256
+ else:
257
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
258
+
259
+ return warp_mat
260
+
261
+
262
+ def top_down_affine(
263
+ input_size: dict, bbox_scale: dict, bbox_center: dict, img: np.ndarray
264
+ ) -> Tuple[np.ndarray, np.ndarray]:
265
+ """Get the bbox image as the model input by affine transform.
266
+
267
+ Args:
268
+ input_size (dict): The input size of the model.
269
+ bbox_scale (dict): The bbox scale of the img.
270
+ bbox_center (dict): The bbox center of the img.
271
+ img (np.ndarray): The original image.
272
+
273
+ Returns:
274
+ tuple: A tuple containing the transformed image and the bbox scale.
275
+ - np.ndarray[float32]: img after affine transform.
276
+ - np.ndarray[float32]: bbox scale after affine transform.
277
+ """
278
+ w, h = input_size
279
+ warp_size = (int(w), int(h))
280
+
281
+ # reshape bbox to fixed aspect ratio
282
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
283
+
284
+ # get the affine matrix
285
+ center = bbox_center
286
+ scale = bbox_scale
287
+ rot = 0
288
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
289
+
290
+ # do affine transform
291
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
292
+
293
+ return img, bbox_scale
294
+
295
+
296
+ def get_simcc_maximum(
297
+ simcc_x: np.ndarray, simcc_y: np.ndarray
298
+ ) -> Tuple[np.ndarray, np.ndarray]:
299
+ """Get maximum response location and value from simcc representations.
300
+
301
+ Note:
302
+ instance number: N
303
+ num_keypoints: K
304
+ heatmap height: H
305
+ heatmap width: W
306
+
307
+ Args:
308
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
309
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
310
+
311
+ Returns:
312
+ tuple:
313
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
314
+ (K, 2) or (N, K, 2)
315
+ - vals (np.ndarray): values of maximum heatmap responses in shape
316
+ (K,) or (N, K)
317
+ """
318
+ N, K, Wx = simcc_x.shape
319
+ simcc_x = simcc_x.reshape(N * K, -1)
320
+ simcc_y = simcc_y.reshape(N * K, -1)
321
+
322
+ # get maximum value locations
323
+ x_locs = np.argmax(simcc_x, axis=1)
324
+ y_locs = np.argmax(simcc_y, axis=1)
325
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
326
+ max_val_x = np.amax(simcc_x, axis=1)
327
+ max_val_y = np.amax(simcc_y, axis=1)
328
+
329
+ # get maximum value across x and y axis
330
+ mask = max_val_x > max_val_y
331
+ max_val_x[mask] = max_val_y[mask]
332
+ vals = max_val_x
333
+ locs[vals <= 0.0] = -1
334
+
335
+ # reshape
336
+ locs = locs.reshape(N, K, 2)
337
+ vals = vals.reshape(N, K)
338
+
339
+ return locs, vals
340
+
341
+
342
+ def decode(
343
+ simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio
344
+ ) -> Tuple[np.ndarray, np.ndarray]:
345
+ """Modulate simcc distribution with Gaussian.
346
+
347
+ Args:
348
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
349
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
350
+ simcc_split_ratio (float): The split ratio of simcc.
351
+
352
+ Returns:
353
+ tuple: A tuple containing keypoints and scores.
354
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
355
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
356
+ """
357
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
358
+ keypoints /= simcc_split_ratio
359
+
360
+ return keypoints, scores
361
+
362
+
363
+ def inference_pose(session, out_bbox, oriImg):
364
+ h, w = session.get_inputs()[0].shape[2:]
365
+ model_input_size = (w, h)
366
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
367
+ outputs = inference(session, resized_img)
368
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
369
+
370
+ return keypoints, scores
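
A minimal usage sketch of the RTMPose helpers above, assuming the DWPose ONNX weights have already been downloaded; inference_pose chains preprocess, inference, and postprocess, and the image path and detector box below are placeholders.

import cv2
import numpy as np
import onnxruntime as ort

from src.dwpose.onnxpose import inference_pose

# Placeholder inputs: any BGR image and one person box in (x1, y1, x2, y2) format.
sess = ort.InferenceSession(
    "pretrained_weights/DWPose/dw-ll_ucoco_384.onnx",
    providers=["CPUExecutionProvider"],
)
img = cv2.imread("example.jpg")
bboxes = np.array([[10.0, 20.0, 200.0, 400.0]])

keypoints, scores = inference_pose(sess, bboxes, img)
print(keypoints.shape, scores.shape)  # (num_boxes, num_keypoints, 2), (num_boxes, num_keypoints)
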
src/dwpose/util.py ADDED
@@ -0,0 +1,378 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ import math
3
+ import numpy as np
4
+ import matplotlib
5
+ import cv2
6
+
7
+
8
+ eps = 0.01
9
+
10
+
11
+ def smart_resize(x, s):
12
+ Ht, Wt = s
13
+ if x.ndim == 2:
14
+ Ho, Wo = x.shape
15
+ Co = 1
16
+ else:
17
+ Ho, Wo, Co = x.shape
18
+ if Co == 3 or Co == 1:
19
+ k = float(Ht + Wt) / float(Ho + Wo)
20
+ return cv2.resize(
21
+ x,
22
+ (int(Wt), int(Ht)),
23
+ interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
24
+ )
25
+ else:
26
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
27
+
28
+
29
+ def smart_resize_k(x, fx, fy):
30
+ if x.ndim == 2:
31
+ Ho, Wo = x.shape
32
+ Co = 1
33
+ else:
34
+ Ho, Wo, Co = x.shape
35
+ Ht, Wt = Ho * fy, Wo * fx
36
+ if Co == 3 or Co == 1:
37
+ k = float(Ht + Wt) / float(Ho + Wo)
38
+ return cv2.resize(
39
+ x,
40
+ (int(Wt), int(Ht)),
41
+ interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4,
42
+ )
43
+ else:
44
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
45
+
46
+
47
+ def padRightDownCorner(img, stride, padValue):
48
+ h = img.shape[0]
49
+ w = img.shape[1]
50
+
51
+ pad = 4 * [None]
52
+ pad[0] = 0 # up
53
+ pad[1] = 0 # left
54
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
55
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
56
+
57
+ img_padded = img
58
+ pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
59
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
60
+ pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
61
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
62
+ pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
63
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
64
+ pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
65
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
66
+
67
+ return img_padded, pad
68
+
69
+
70
+ def transfer(model, model_weights):
71
+ transfered_model_weights = {}
72
+ for weights_name in model.state_dict().keys():
73
+ transfered_model_weights[weights_name] = model_weights[
74
+ ".".join(weights_name.split(".")[1:])
75
+ ]
76
+ return transfered_model_weights
77
+
78
+
79
+ def draw_bodypose(canvas, candidate, subset):
80
+ H, W, C = canvas.shape
81
+ candidate = np.array(candidate)
82
+ subset = np.array(subset)
83
+
84
+ stickwidth = 4
85
+
86
+ limbSeq = [
87
+ [2, 3],
88
+ [2, 6],
89
+ [3, 4],
90
+ [4, 5],
91
+ [6, 7],
92
+ [7, 8],
93
+ [2, 9],
94
+ [9, 10],
95
+ [10, 11],
96
+ [2, 12],
97
+ [12, 13],
98
+ [13, 14],
99
+ [2, 1],
100
+ [1, 15],
101
+ [15, 17],
102
+ [1, 16],
103
+ [16, 18],
104
+ [3, 17],
105
+ [6, 18],
106
+ ]
107
+
108
+ colors = [
109
+ [255, 0, 0],
110
+ [255, 85, 0],
111
+ [255, 170, 0],
112
+ [255, 255, 0],
113
+ [170, 255, 0],
114
+ [85, 255, 0],
115
+ [0, 255, 0],
116
+ [0, 255, 85],
117
+ [0, 255, 170],
118
+ [0, 255, 255],
119
+ [0, 170, 255],
120
+ [0, 85, 255],
121
+ [0, 0, 255],
122
+ [85, 0, 255],
123
+ [170, 0, 255],
124
+ [255, 0, 255],
125
+ [255, 0, 170],
126
+ [255, 0, 85],
127
+ ]
128
+
129
+ for i in range(17):
130
+ for n in range(len(subset)):
131
+ index = subset[n][np.array(limbSeq[i]) - 1]
132
+ if -1 in index:
133
+ continue
134
+ Y = candidate[index.astype(int), 0] * float(W)
135
+ X = candidate[index.astype(int), 1] * float(H)
136
+ mX = np.mean(X)
137
+ mY = np.mean(Y)
138
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
139
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
140
+ polygon = cv2.ellipse2Poly(
141
+ (int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1
142
+ )
143
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
144
+
145
+ canvas = (canvas * 0.6).astype(np.uint8)
146
+
147
+ for i in range(18):
148
+ for n in range(len(subset)):
149
+ index = int(subset[n][i])
150
+ if index == -1:
151
+ continue
152
+ x, y = candidate[index][0:2]
153
+ x = int(x * W)
154
+ y = int(y * H)
155
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
156
+
157
+ return canvas
158
+
159
+
160
+ def draw_handpose(canvas, all_hand_peaks):
161
+ H, W, C = canvas.shape
162
+
163
+ edges = [
164
+ [0, 1],
165
+ [1, 2],
166
+ [2, 3],
167
+ [3, 4],
168
+ [0, 5],
169
+ [5, 6],
170
+ [6, 7],
171
+ [7, 8],
172
+ [0, 9],
173
+ [9, 10],
174
+ [10, 11],
175
+ [11, 12],
176
+ [0, 13],
177
+ [13, 14],
178
+ [14, 15],
179
+ [15, 16],
180
+ [0, 17],
181
+ [17, 18],
182
+ [18, 19],
183
+ [19, 20],
184
+ ]
185
+
186
+ for peaks in all_hand_peaks:
187
+ peaks = np.array(peaks)
188
+
189
+ for ie, e in enumerate(edges):
190
+ x1, y1 = peaks[e[0]]
191
+ x2, y2 = peaks[e[1]]
192
+ x1 = int(x1 * W)
193
+ y1 = int(y1 * H)
194
+ x2 = int(x2 * W)
195
+ y2 = int(y2 * H)
196
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
197
+ cv2.line(
198
+ canvas,
199
+ (x1, y1),
200
+ (x2, y2),
201
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0])
202
+ * 255,
203
+ thickness=2,
204
+ )
205
+
206
+ for i, keypoint in enumerate(peaks):
207
+ x, y = keypoint
208
+ x = int(x * W)
209
+ y = int(y * H)
210
+ if x > eps and y > eps:
211
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
212
+ return canvas
213
+
214
+
215
+ def draw_facepose(canvas, all_lmks):
216
+ H, W, C = canvas.shape
217
+ for lmks in all_lmks:
218
+ lmks = np.array(lmks)
219
+ for lmk in lmks:
220
+ x, y = lmk
221
+ x = int(x * W)
222
+ y = int(y * H)
223
+ if x > eps and y > eps:
224
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
225
+ return canvas
226
+
227
+
228
+ # detect hand according to body pose keypoints
229
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
230
+ def handDetect(candidate, subset, oriImg):
231
+ # right hand: wrist 4, elbow 3, shoulder 2
232
+ # left hand: wrist 7, elbow 6, shoulder 5
233
+ ratioWristElbow = 0.33
234
+ detect_result = []
235
+ image_height, image_width = oriImg.shape[0:2]
236
+ for person in subset.astype(int):
237
+ # if any of three not detected
238
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
239
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
240
+ if not (has_left or has_right):
241
+ continue
242
+ hands = []
243
+ # left hand
244
+ if has_left:
245
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
246
+ x1, y1 = candidate[left_shoulder_index][:2]
247
+ x2, y2 = candidate[left_elbow_index][:2]
248
+ x3, y3 = candidate[left_wrist_index][:2]
249
+ hands.append([x1, y1, x2, y2, x3, y3, True])
250
+ # right hand
251
+ if has_right:
252
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[
253
+ [2, 3, 4]
254
+ ]
255
+ x1, y1 = candidate[right_shoulder_index][:2]
256
+ x2, y2 = candidate[right_elbow_index][:2]
257
+ x3, y3 = candidate[right_wrist_index][:2]
258
+ hands.append([x1, y1, x2, y2, x3, y3, False])
259
+
260
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
261
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
262
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
263
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
264
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
265
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
266
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
267
+ x = x3 + ratioWristElbow * (x3 - x2)
268
+ y = y3 + ratioWristElbow * (y3 - y2)
269
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
270
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
271
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
272
+ # x-y refers to the center --> offset to topLeft point
273
+ # handRectangle.x -= handRectangle.width / 2.f;
274
+ # handRectangle.y -= handRectangle.height / 2.f;
275
+ x -= width / 2
276
+ y -= width / 2 # width = height
277
+ # overflow the image
278
+ if x < 0:
279
+ x = 0
280
+ if y < 0:
281
+ y = 0
282
+ width1 = width
283
+ width2 = width
284
+ if x + width > image_width:
285
+ width1 = image_width - x
286
+ if y + width > image_height:
287
+ width2 = image_height - y
288
+ width = min(width1, width2)
289
+ # discard hand boxes smaller than 20 pixels
290
+ if width >= 20:
291
+ detect_result.append([int(x), int(y), int(width), is_left])
292
+
293
+ """
294
+ return value: [[x, y, w, True if left hand else False]].
295
+ width=height since the network require squared input.
296
+ x, y is the coordinate of top left
297
+ """
298
+ return detect_result
299
+
300
+
301
+ # Written by Lvmin
302
+ def faceDetect(candidate, subset, oriImg):
303
+ # left right eye ear 14 15 16 17
304
+ detect_result = []
305
+ image_height, image_width = oriImg.shape[0:2]
306
+ for person in subset.astype(int):
307
+ has_head = person[0] > -1
308
+ if not has_head:
309
+ continue
310
+
311
+ has_left_eye = person[14] > -1
312
+ has_right_eye = person[15] > -1
313
+ has_left_ear = person[16] > -1
314
+ has_right_ear = person[17] > -1
315
+
316
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
317
+ continue
318
+
319
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
320
+
321
+ width = 0.0
322
+ x0, y0 = candidate[head][:2]
323
+
324
+ if has_left_eye:
325
+ x1, y1 = candidate[left_eye][:2]
326
+ d = max(abs(x0 - x1), abs(y0 - y1))
327
+ width = max(width, d * 3.0)
328
+
329
+ if has_right_eye:
330
+ x1, y1 = candidate[right_eye][:2]
331
+ d = max(abs(x0 - x1), abs(y0 - y1))
332
+ width = max(width, d * 3.0)
333
+
334
+ if has_left_ear:
335
+ x1, y1 = candidate[left_ear][:2]
336
+ d = max(abs(x0 - x1), abs(y0 - y1))
337
+ width = max(width, d * 1.5)
338
+
339
+ if has_right_ear:
340
+ x1, y1 = candidate[right_ear][:2]
341
+ d = max(abs(x0 - x1), abs(y0 - y1))
342
+ width = max(width, d * 1.5)
343
+
344
+ x, y = x0, y0
345
+
346
+ x -= width
347
+ y -= width
348
+
349
+ if x < 0:
350
+ x = 0
351
+
352
+ if y < 0:
353
+ y = 0
354
+
355
+ width1 = width * 2
356
+ width2 = width * 2
357
+
358
+ if x + width > image_width:
359
+ width1 = image_width - x
360
+
361
+ if y + width > image_height:
362
+ width2 = image_height - y
363
+
364
+ width = min(width1, width2)
365
+
366
+ if width >= 20:
367
+ detect_result.append([int(x), int(y), int(width)])
368
+
369
+ return detect_result
370
+
371
+
372
+ # get max index of 2d array
373
+ def npmax(array):
374
+ arrayindex = array.argmax(1)
375
+ arrayvalue = array.max(1)
376
+ i = arrayvalue.argmax()
377
+ j = arrayindex[i]
378
+ return i, j
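
A short, illustrative sketch of the drawing helpers above; the landmark values are made up, coordinates are expected to be normalized to [0, 1], and values at or below eps (0.01) are treated as missing and skipped.

import numpy as np

from src.dwpose.util import draw_facepose, draw_handpose

canvas = np.zeros((512, 512, 3), dtype=np.uint8)

# One face with three illustrative landmarks, normalized to the canvas size.
face_lmks = np.array([[[0.40, 0.30], [0.60, 0.30], [0.50, 0.45]]])
canvas = draw_facepose(canvas, face_lmks)

# All-zero peaks fall below eps, so draw_handpose draws nothing for them.
hand_peaks = np.zeros((1, 21, 2))
canvas = draw_handpose(canvas, hand_peaks)
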
src/dwpose/wholebody.py ADDED
@@ -0,0 +1,48 @@
1
+ # https://github.com/IDEA-Research/DWPose
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+ from .onnxdet import inference_detector
9
+ from .onnxpose import inference_pose
10
+
11
+ ModelDataPathPrefix = Path("./pretrained_weights")
12
+
13
+
14
+ class Wholebody:
15
+ def __init__(self, device="cuda:0"):
16
+ providers = (
17
+ ["CPUExecutionProvider"] if device == "cpu" else ["CUDAExecutionProvider"]
18
+ )
19
+ onnx_det = ModelDataPathPrefix.joinpath("DWPose/yolox_l.onnx")
20
+ onnx_pose = ModelDataPathPrefix.joinpath("DWPose/dw-ll_ucoco_384.onnx")
21
+
22
+ self.session_det = ort.InferenceSession(
23
+ path_or_bytes=onnx_det, providers=providers
24
+ )
25
+ self.session_pose = ort.InferenceSession(
26
+ path_or_bytes=onnx_pose, providers=providers
27
+ )
28
+
29
+ def __call__(self, oriImg):
30
+ det_result = inference_detector(self.session_det, oriImg)
31
+ keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
32
+
33
+ keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
34
+ # compute neck joint
35
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
36
+ # neck score when visualizing pred
37
+ neck[:, 2:4] = np.logical_and(
38
+ keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3
39
+ ).astype(int)
40
+ new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
41
+ mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
42
+ openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
43
+ new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
44
+ keypoints_info = new_keypoints_info
45
+
46
+ keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
47
+
48
+ return keypoints, scores
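
A hedged usage sketch of the Wholebody wrapper above; it assumes the two ONNX files exist under ./pretrained_weights/DWPose as hard-coded in the class, and the input image path is a placeholder.

import cv2

from src.dwpose.wholebody import Wholebody

pose_estimator = Wholebody(device="cpu")  # "cuda:0" selects the CUDAExecutionProvider
frame = cv2.imread("example.png")         # placeholder BGR input frame
keypoints, scores = pose_estimator(frame)
# keypoints: (num_person, num_keypoints, 2) pixel coordinates, body joints remapped to OpenPose order
# scores:    (num_person, num_keypoints) confidence values
print(keypoints.shape, scores.shape)
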
src/models/attention.py ADDED
@@ -0,0 +1,443 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ ):
314
+ super().__init__()
315
+ self.only_cross_attention = only_cross_attention
316
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
317
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318
+ self.unet_use_temporal_attention = unet_use_temporal_attention
319
+
320
+ # SC-Attn
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ upcast_attention=upcast_attention,
328
+ )
329
+ self.norm1 = (
330
+ AdaLayerNorm(dim, num_embeds_ada_norm)
331
+ if self.use_ada_layer_norm
332
+ else nn.LayerNorm(dim)
333
+ )
334
+
335
+ # Cross-Attn
336
+ if cross_attention_dim is not None:
337
+ self.attn2 = Attention(
338
+ query_dim=dim,
339
+ cross_attention_dim=cross_attention_dim,
340
+ heads=num_attention_heads,
341
+ dim_head=attention_head_dim,
342
+ dropout=dropout,
343
+ bias=attention_bias,
344
+ upcast_attention=upcast_attention,
345
+ )
346
+ else:
347
+ self.attn2 = None
348
+
349
+ if cross_attention_dim is not None:
350
+ self.norm2 = (
351
+ AdaLayerNorm(dim, num_embeds_ada_norm)
352
+ if self.use_ada_layer_norm
353
+ else nn.LayerNorm(dim)
354
+ )
355
+ else:
356
+ self.norm2 = None
357
+
358
+ # Feed-forward
359
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
360
+ self.norm3 = nn.LayerNorm(dim)
361
+ self.use_ada_layer_norm_zero = False
362
+
363
+ # Temp-Attn
364
+ assert unet_use_temporal_attention is not None
365
+ if unet_use_temporal_attention:
366
+ self.attn_temp = Attention(
367
+ query_dim=dim,
368
+ heads=num_attention_heads,
369
+ dim_head=attention_head_dim,
370
+ dropout=dropout,
371
+ bias=attention_bias,
372
+ upcast_attention=upcast_attention,
373
+ )
374
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
375
+ self.norm_temp = (
376
+ AdaLayerNorm(dim, num_embeds_ada_norm)
377
+ if self.use_ada_layer_norm
378
+ else nn.LayerNorm(dim)
379
+ )
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states,
384
+ encoder_hidden_states=None,
385
+ timestep=None,
386
+ attention_mask=None,
387
+ video_length=None,
388
+ ):
389
+ norm_hidden_states = (
390
+ self.norm1(hidden_states, timestep)
391
+ if self.use_ada_layer_norm
392
+ else self.norm1(hidden_states)
393
+ )
394
+
395
+ if self.unet_use_cross_frame_attention:
396
+ hidden_states = (
397
+ self.attn1(
398
+ norm_hidden_states,
399
+ attention_mask=attention_mask,
400
+ video_length=video_length,
401
+ )
402
+ + hidden_states
403
+ )
404
+ else:
405
+ hidden_states = (
406
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
407
+ + hidden_states
408
+ )
409
+
410
+ if self.attn2 is not None:
411
+ # Cross-Attention
412
+ norm_hidden_states = (
413
+ self.norm2(hidden_states, timestep)
414
+ if self.use_ada_layer_norm
415
+ else self.norm2(hidden_states)
416
+ )
417
+ hidden_states = (
418
+ self.attn2(
419
+ norm_hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ + hidden_states
424
+ )
425
+
426
+ # Feed-forward
427
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
428
+
429
+ # Temporal-Attention
430
+ if self.unet_use_temporal_attention:
431
+ d = hidden_states.shape[1]
432
+ hidden_states = rearrange(
433
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
434
+ )
435
+ norm_hidden_states = (
436
+ self.norm_temp(hidden_states, timestep)
437
+ if self.use_ada_layer_norm
438
+ else self.norm_temp(hidden_states)
439
+ )
440
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
441
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
442
+
443
+ return hidden_states
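
A small smoke test of the TemporalBasicTransformerBlock defined above; the dimensions are arbitrary, and the cross-attention, cross-frame, and temporal branches are switched off to keep the sketch minimal.

import torch

from src.models.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=64,
    num_attention_heads=8,
    attention_head_dim=8,
    cross_attention_dim=None,            # skip the cross-attention branch
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,   # the constructor asserts this is not None
)

video_length = 4
x = torch.randn(2 * video_length, 16, 64)  # (batch * frames, tokens, channels)
out = block(x, video_length=video_length)
print(out.shape)  # torch.Size([8, 16, 64])
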
src/models/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation here and is only needed with PyTorch 1.13.
340
+ # The PyTorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor did in PyTorch 1.13,
341
+ # so XFormersAttnProcessor is not used here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
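
A hedged sketch of building the Vanilla motion module above and pushing a 5-D video latent through it; the channel and frame counts are arbitrary, and because proj_out is zero-initialized the module starts out as an identity mapping.

import torch

from src.models.motion_module import get_motion_module

motion_module = get_motion_module(
    in_channels=64,
    motion_module_type="Vanilla",
    motion_module_kwargs={"temporal_position_encoding": True},
)

x = torch.randn(1, 64, 8, 16, 16)  # (batch, channels, frames, height, width)
out = motion_module(x, temb=None, encoder_hidden_states=None)
print(out.shape, torch.allclose(out, x))  # same shape as the input; True at initialization
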
src/models/mutual_self_attention.py ADDED
@@ -0,0 +1,363 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from src.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ ):
104
+ if self.use_ada_layer_norm: # False
105
+ norm_hidden_states = self.norm1(hidden_states, timestep)
106
+ elif self.use_ada_layer_norm_zero:
107
+ (
108
+ norm_hidden_states,
109
+ gate_msa,
110
+ shift_mlp,
111
+ scale_mlp,
112
+ gate_mlp,
113
+ ) = self.norm1(
114
+ hidden_states,
115
+ timestep,
116
+ class_labels,
117
+ hidden_dtype=hidden_states.dtype,
118
+ )
119
+ else:
120
+ norm_hidden_states = self.norm1(hidden_states)
121
+
122
+ # 1. Self-Attention
123
+ # self.only_cross_attention = False
124
+ cross_attention_kwargs = (
125
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
126
+ )
127
+ if self.only_cross_attention:
128
+ attn_output = self.attn1(
129
+ norm_hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states
131
+ if self.only_cross_attention
132
+ else None,
133
+ attention_mask=attention_mask,
134
+ **cross_attention_kwargs,
135
+ )
136
+ else:
137
+ if MODE == "write":
138
+ self.bank.append(norm_hidden_states.clone())
139
+ attn_output = self.attn1(
140
+ norm_hidden_states,
141
+ encoder_hidden_states=encoder_hidden_states
142
+ if self.only_cross_attention
143
+ else None,
144
+ attention_mask=attention_mask,
145
+ **cross_attention_kwargs,
146
+ )
147
+ if MODE == "read":
148
+ bank_fea = [
149
+ rearrange(
150
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
151
+ "b t l c -> (b t) l c",
152
+ )
153
+ for d in self.bank
154
+ ]
155
+ modify_norm_hidden_states = torch.cat(
156
+ [norm_hidden_states] + bank_fea, dim=1
157
+ )
158
+ hidden_states_uc = (
159
+ self.attn1(
160
+ norm_hidden_states,
161
+ encoder_hidden_states=modify_norm_hidden_states,
162
+ attention_mask=attention_mask,
163
+ )
164
+ + hidden_states
165
+ )
166
+ if do_classifier_free_guidance:
167
+ hidden_states_c = hidden_states_uc.clone()
168
+ _uc_mask = uc_mask.clone()
169
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
170
+ _uc_mask = (
171
+ torch.Tensor(
172
+ [1] * (hidden_states.shape[0] // 2)
173
+ + [0] * (hidden_states.shape[0] // 2)
174
+ )
175
+ .to(device)
176
+ .bool()
177
+ )
178
+ hidden_states_c[_uc_mask] = (
179
+ self.attn1(
180
+ norm_hidden_states[_uc_mask],
181
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
182
+ attention_mask=attention_mask,
183
+ )
184
+ + hidden_states[_uc_mask]
185
+ )
186
+ hidden_states = hidden_states_c.clone()
187
+ else:
188
+ hidden_states = hidden_states_uc
189
+
190
+ # self.bank.clear()
191
+ if self.attn2 is not None:
192
+ # Cross-Attention
193
+ norm_hidden_states = (
194
+ self.norm2(hidden_states, timestep)
195
+ if self.use_ada_layer_norm
196
+ else self.norm2(hidden_states)
197
+ )
198
+ hidden_states = (
199
+ self.attn2(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ attention_mask=attention_mask,
203
+ )
204
+ + hidden_states
205
+ )
206
+
207
+ # Feed-forward
208
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209
+
210
+ # Temporal-Attention
211
+ if self.unet_use_temporal_attention:
212
+ d = hidden_states.shape[1]
213
+ hidden_states = rearrange(
214
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
215
+ )
216
+ norm_hidden_states = (
217
+ self.norm_temp(hidden_states, timestep)
218
+ if self.use_ada_layer_norm
219
+ else self.norm_temp(hidden_states)
220
+ )
221
+ hidden_states = (
222
+ self.attn_temp(norm_hidden_states) + hidden_states
223
+ )
224
+ hidden_states = rearrange(
225
+ hidden_states, "(b d) f c -> (b f) d c", d=d
226
+ )
227
+
228
+ return hidden_states
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = attn_output + hidden_states
233
+
234
+ if self.attn2 is not None:
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+
241
+ # 2. Cross-Attention
242
+ attn_output = self.attn2(
243
+ norm_hidden_states,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ attention_mask=encoder_attention_mask,
246
+ **cross_attention_kwargs,
247
+ )
248
+ hidden_states = attn_output + hidden_states
249
+
250
+ # 3. Feed-forward
251
+ norm_hidden_states = self.norm3(hidden_states)
252
+
253
+ if self.use_ada_layer_norm_zero:
254
+ norm_hidden_states = (
255
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256
+ )
257
+
258
+ ff_output = self.ff(norm_hidden_states)
259
+
260
+ if self.use_ada_layer_norm_zero:
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+
263
+ hidden_states = ff_output + hidden_states
264
+
265
+ return hidden_states
266
+
267
+ if self.reference_attn:
268
+ if self.fusion_blocks == "midup":
269
+ attn_modules = [
270
+ module
271
+ for module in (
272
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273
+ )
274
+ if isinstance(module, BasicTransformerBlock)
275
+ or isinstance(module, TemporalBasicTransformerBlock)
276
+ ]
277
+ elif self.fusion_blocks == "full":
278
+ attn_modules = [
279
+ module
280
+ for module in torch_dfs(self.unet)
281
+ if isinstance(module, BasicTransformerBlock)
282
+ or isinstance(module, TemporalBasicTransformerBlock)
283
+ ]
284
+ attn_modules = sorted(
285
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286
+ )
287
+
288
+ for i, module in enumerate(attn_modules):
289
+ module._original_inner_forward = module.forward
290
+ if isinstance(module, BasicTransformerBlock):
291
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
292
+ module, BasicTransformerBlock
293
+ )
294
+ if isinstance(module, TemporalBasicTransformerBlock):
295
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
296
+ module, TemporalBasicTransformerBlock
297
+ )
298
+
299
+ module.bank = []
300
+ module.attn_weight = float(i) / float(len(attn_modules))
301
+
302
+ def update(self, writer, dtype=torch.float16):
303
+ if self.reference_attn:
304
+ if self.fusion_blocks == "midup":
305
+ reader_attn_modules = [
306
+ module
307
+ for module in (
308
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309
+ )
310
+ if isinstance(module, TemporalBasicTransformerBlock)
311
+ ]
312
+ writer_attn_modules = [
313
+ module
314
+ for module in (
315
+ torch_dfs(writer.unet.mid_block)
316
+ + torch_dfs(writer.unet.up_blocks)
317
+ )
318
+ if isinstance(module, BasicTransformerBlock)
319
+ ]
320
+ elif self.fusion_blocks == "full":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in torch_dfs(self.unet)
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in torch_dfs(writer.unet)
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ reader_attn_modules = sorted(
332
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+ writer_attn_modules = sorted(
335
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336
+ )
337
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
338
+ r.bank = [v.clone().to(dtype) for v in w.bank]
339
+ # w.bank.clear()
340
+
341
+ def clear(self):
342
+ if self.reference_attn:
343
+ if self.fusion_blocks == "midup":
344
+ reader_attn_modules = [
345
+ module
346
+ for module in (
347
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348
+ )
349
+ if isinstance(module, BasicTransformerBlock)
350
+ or isinstance(module, TemporalBasicTransformerBlock)
351
+ ]
352
+ elif self.fusion_blocks == "full":
353
+ reader_attn_modules = [
354
+ module
355
+ for module in torch_dfs(self.unet)
356
+ if isinstance(module, BasicTransformerBlock)
357
+ or isinstance(module, TemporalBasicTransformerBlock)
358
+ ]
359
+ reader_attn_modules = sorted(
360
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361
+ )
362
+ for r in reader_attn_modules:
363
+ r.bank.clear()
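
Editor's note: a minimal, self-contained sketch (not part of the commit) of the matching rule used by `update()` above. Writer and reader transformer blocks are paired by sorting both lists on their hidden size (`norm1.normalized_shape[0]`, largest first) and zipping them together, after which each reader `bank` receives a half-precision copy of the writer's cached reference features. The toy block and tensor sizes below are made up for illustration.

import torch
import torch.nn as nn

class _ToyBlock(nn.Module):
    # Stand-in for BasicTransformerBlock (writer) / TemporalBasicTransformerBlock (reader).
    def __init__(self, dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.bank = []

writer_blocks = [_ToyBlock(320), _ToyBlock(1280), _ToyBlock(640)]   # 2D reference UNet
reader_blocks = [_ToyBlock(640), _ToyBlock(320), _ToyBlock(1280)]   # 3D denoising UNet

# Pretend the writer already ran once and cached reference features in its banks.
for blk in writer_blocks:
    blk.bank.append(torch.randn(1, 77, blk.norm1.normalized_shape[0]))

# Sort both sides by hidden size (largest first) and copy writer banks into the reader.
writer_sorted = sorted(writer_blocks, key=lambda x: -x.norm1.normalized_shape[0])
reader_sorted = sorted(reader_blocks, key=lambda x: -x.norm1.normalized_shape[0])
for r, w in zip(reader_sorted, writer_sorted):
    r.bank = [v.clone().to(torch.float16) for v in w.bank]

print([r.norm1.normalized_shape[0] for r in reader_sorted])  # [1280, 640, 320]
print(reader_sorted[0].bank[0].dtype)                        # torch.float16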
src/models/pose_guider.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from src.models.motion_module import zero_module
9
+ from src.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
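
Editor's note: a quick shape check for PoseGuider (a sketch, not part of the commit). Run it from the repo root so `src` is importable; `conditioning_embedding_channels=320` is only an illustrative value matching the first block of an SD 1.5-style UNet.

import torch
from src.models.pose_guider import PoseGuider

pose_guider = PoseGuider(conditioning_embedding_channels=320)

# (batch, RGB pose channels, frames, height, width)
pose = torch.randn(1, 3, 8, 512, 512)
with torch.no_grad():
    feat = pose_guider(pose)

# Three stride-2 convolutions shrink H and W by 8x; the frame axis is untouched.
print(feat.shape)  # torch.Size([1, 320, 8, 64, 64])

Because `conv_out` is wrapped in `zero_module`, this output is all zeros before training, so adding it to the UNet latents at step 0 leaves the pretrained model unchanged.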
src/models/resnet.py ADDED
@@ -0,0 +1,252 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227
+
228
+ if temb is not None and self.time_embedding_norm == "default":
229
+ hidden_states = hidden_states + temb
230
+
231
+ hidden_states = self.norm2(hidden_states)
232
+
233
+ if temb is not None and self.time_embedding_norm == "scale_shift":
234
+ scale, shift = torch.chunk(temb, 2, dim=1)
235
+ hidden_states = hidden_states * (1 + scale) + shift
236
+
237
+ hidden_states = self.nonlinearity(hidden_states)
238
+
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.conv2(hidden_states)
241
+
242
+ if self.conv_shortcut is not None:
243
+ input_tensor = self.conv_shortcut(input_tensor)
244
+
245
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246
+
247
+ return output_tensor
248
+
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
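
Editor's note: a short sketch of the "inflation" trick used throughout this file (not part of the commit). A standard 2D layer is applied frame by frame to a 5D video tensor by folding the frame axis into the batch axis and unfolding it afterwards. Run from the repo root; sizes are illustrative.

import torch
from src.models.resnet import InflatedConv3d, InflatedGroupNorm

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
norm = InflatedGroupNorm(num_groups=4, num_channels=8)

x = torch.randn(2, 4, 16, 32, 32)  # (batch, channels, frames, height, width)
y = norm(conv(x))
print(y.shape)  # torch.Size([2, 8, 16, 32, 32]) -- purely spatial ops, frames preserved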
src/models/transformer_2d.py ADDED
@@ -0,0 +1,396 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.embeddings import CaptionProjection
8
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from diffusers.models.normalization import AdaLayerNormSingle
11
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
12
+ from torch import nn
13
+
14
+ from .attention import BasicTransformerBlock
15
+
16
+
17
+ @dataclass
18
+ class Transformer2DModelOutput(BaseOutput):
19
+ """
20
+ The output of [`Transformer2DModel`].
21
+
22
+ Args:
23
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
24
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
25
+ distributions for the unnoised latent pixels.
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+ ref_feature: torch.FloatTensor
30
+
31
+
32
+ class Transformer2DModel(ModelMixin, ConfigMixin):
33
+ """
34
+ A 2D Transformer model for image-like data.
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
39
+ in_channels (`int`, *optional*):
40
+ The number of channels in the input and output (specify if the input is **continuous**).
41
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
42
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
43
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
44
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
45
+ This is fixed during training since it is used to learn a number of position embeddings.
46
+ num_vector_embeds (`int`, *optional*):
47
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
48
+ Includes the class for the masked latent pixel.
49
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
50
+ num_embeds_ada_norm ( `int`, *optional*):
51
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
52
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
53
+ added to the hidden states.
54
+
55
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
56
+ attention_bias (`bool`, *optional*):
57
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
58
+ """
59
+
60
+ _supports_gradient_checkpointing = True
61
+
62
+ @register_to_config
63
+ def __init__(
64
+ self,
65
+ num_attention_heads: int = 16,
66
+ attention_head_dim: int = 88,
67
+ in_channels: Optional[int] = None,
68
+ out_channels: Optional[int] = None,
69
+ num_layers: int = 1,
70
+ dropout: float = 0.0,
71
+ norm_num_groups: int = 32,
72
+ cross_attention_dim: Optional[int] = None,
73
+ attention_bias: bool = False,
74
+ sample_size: Optional[int] = None,
75
+ num_vector_embeds: Optional[int] = None,
76
+ patch_size: Optional[int] = None,
77
+ activation_fn: str = "geglu",
78
+ num_embeds_ada_norm: Optional[int] = None,
79
+ use_linear_projection: bool = False,
80
+ only_cross_attention: bool = False,
81
+ double_self_attention: bool = False,
82
+ upcast_attention: bool = False,
83
+ norm_type: str = "layer_norm",
84
+ norm_elementwise_affine: bool = True,
85
+ norm_eps: float = 1e-5,
86
+ attention_type: str = "default",
87
+ caption_channels: int = None,
88
+ ):
89
+ super().__init__()
90
+ self.use_linear_projection = use_linear_projection
91
+ self.num_attention_heads = num_attention_heads
92
+ self.attention_head_dim = attention_head_dim
93
+ inner_dim = num_attention_heads * attention_head_dim
94
+
95
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
96
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
97
+
98
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
99
+ # Define whether input is continuous or discrete depending on configuration
100
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
101
+ self.is_input_vectorized = num_vector_embeds is not None
102
+ self.is_input_patches = in_channels is not None and patch_size is not None
103
+
104
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
105
+ deprecation_message = (
106
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
107
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
108
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
109
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
110
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
111
+ )
112
+ deprecate(
113
+ "norm_type!=num_embeds_ada_norm",
114
+ "1.0.0",
115
+ deprecation_message,
116
+ standard_warn=False,
117
+ )
118
+ norm_type = "ada_norm"
119
+
120
+ if self.is_input_continuous and self.is_input_vectorized:
121
+ raise ValueError(
122
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
123
+ " sure that either `in_channels` or `num_vector_embeds` is None."
124
+ )
125
+ elif self.is_input_vectorized and self.is_input_patches:
126
+ raise ValueError(
127
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
128
+ " sure that either `num_vector_embeds` or `num_patches` is None."
129
+ )
130
+ elif (
131
+ not self.is_input_continuous
132
+ and not self.is_input_vectorized
133
+ and not self.is_input_patches
134
+ ):
135
+ raise ValueError(
136
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
137
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
138
+ )
139
+
140
+ # 2. Define input layers
141
+ self.in_channels = in_channels
142
+
143
+ self.norm = torch.nn.GroupNorm(
144
+ num_groups=norm_num_groups,
145
+ num_channels=in_channels,
146
+ eps=1e-6,
147
+ affine=True,
148
+ )
149
+ if use_linear_projection:
150
+ self.proj_in = linear_cls(in_channels, inner_dim)
151
+ else:
152
+ self.proj_in = conv_cls(
153
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
154
+ )
155
+
156
+ # 3. Define transformers blocks
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ BasicTransformerBlock(
160
+ inner_dim,
161
+ num_attention_heads,
162
+ attention_head_dim,
163
+ dropout=dropout,
164
+ cross_attention_dim=cross_attention_dim,
165
+ activation_fn=activation_fn,
166
+ num_embeds_ada_norm=num_embeds_ada_norm,
167
+ attention_bias=attention_bias,
168
+ only_cross_attention=only_cross_attention,
169
+ double_self_attention=double_self_attention,
170
+ upcast_attention=upcast_attention,
171
+ norm_type=norm_type,
172
+ norm_elementwise_affine=norm_elementwise_affine,
173
+ norm_eps=norm_eps,
174
+ attention_type=attention_type,
175
+ )
176
+ for d in range(num_layers)
177
+ ]
178
+ )
179
+
180
+ # 4. Define output layers
181
+ self.out_channels = in_channels if out_channels is None else out_channels
182
+ # TODO: should use out_channels for continuous projections
183
+ if use_linear_projection:
184
+ self.proj_out = linear_cls(inner_dim, in_channels)
185
+ else:
186
+ self.proj_out = conv_cls(
187
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
188
+ )
189
+
190
+ # 5. PixArt-Alpha blocks.
191
+ self.adaln_single = None
192
+ self.use_additional_conditions = False
193
+ if norm_type == "ada_norm_single":
194
+ self.use_additional_conditions = self.config.sample_size == 128
195
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
196
+ # additional conditions until we find better name
197
+ self.adaln_single = AdaLayerNormSingle(
198
+ inner_dim, use_additional_conditions=self.use_additional_conditions
199
+ )
200
+
201
+ self.caption_projection = None
202
+ if caption_channels is not None:
203
+ self.caption_projection = CaptionProjection(
204
+ in_features=caption_channels, hidden_size=inner_dim
205
+ )
206
+
207
+ self.gradient_checkpointing = False
208
+
209
+ def _set_gradient_checkpointing(self, module, value=False):
210
+ if hasattr(module, "gradient_checkpointing"):
211
+ module.gradient_checkpointing = value
212
+
213
+ def forward(
214
+ self,
215
+ hidden_states: torch.Tensor,
216
+ encoder_hidden_states: Optional[torch.Tensor] = None,
217
+ timestep: Optional[torch.LongTensor] = None,
218
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
219
+ class_labels: Optional[torch.LongTensor] = None,
220
+ cross_attention_kwargs: Dict[str, Any] = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ encoder_attention_mask: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ):
225
+ """
226
+ The [`Transformer2DModel`] forward method.
227
+
228
+ Args:
229
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
230
+ Input `hidden_states`.
231
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
232
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
233
+ self-attention.
234
+ timestep ( `torch.LongTensor`, *optional*):
235
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
236
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
237
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
238
+ `AdaLayerZeroNorm`.
239
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
240
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
241
+ `self.processor` in
242
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
243
+ attention_mask ( `torch.Tensor`, *optional*):
244
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
245
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
246
+ negative values to the attention scores corresponding to "discard" tokens.
247
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
248
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
249
+
250
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
251
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
252
+
253
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
254
+ above. This bias will be added to the cross-attention scores.
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
261
+ `tuple` where the first element is the sample tensor.
262
+ """
263
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
264
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
265
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
266
+ # expects mask of shape:
267
+ # [batch, key_tokens]
268
+ # adds singleton query_tokens dimension:
269
+ # [batch, 1, key_tokens]
270
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
271
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
272
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
273
+ if attention_mask is not None and attention_mask.ndim == 2:
274
+ # assume that mask is expressed as:
275
+ # (1 = keep, 0 = discard)
276
+ # convert mask into a bias that can be added to attention scores:
277
+ # (keep = +0, discard = -10000.0)
278
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
279
+ attention_mask = attention_mask.unsqueeze(1)
280
+
281
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
282
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
283
+ encoder_attention_mask = (
284
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
285
+ ) * -10000.0
286
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
287
+
288
+ # Retrieve lora scale.
289
+ lora_scale = (
290
+ cross_attention_kwargs.get("scale", 1.0)
291
+ if cross_attention_kwargs is not None
292
+ else 1.0
293
+ )
294
+
295
+ # 1. Input
296
+ batch, _, height, width = hidden_states.shape
297
+ residual = hidden_states
298
+
299
+ hidden_states = self.norm(hidden_states)
300
+ if not self.use_linear_projection:
301
+ hidden_states = (
302
+ self.proj_in(hidden_states, scale=lora_scale)
303
+ if not USE_PEFT_BACKEND
304
+ else self.proj_in(hidden_states)
305
+ )
306
+ inner_dim = hidden_states.shape[1]
307
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
308
+ batch, height * width, inner_dim
309
+ )
310
+ else:
311
+ inner_dim = hidden_states.shape[1]
312
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
313
+ batch, height * width, inner_dim
314
+ )
315
+ hidden_states = (
316
+ self.proj_in(hidden_states, scale=lora_scale)
317
+ if not USE_PEFT_BACKEND
318
+ else self.proj_in(hidden_states)
319
+ )
320
+
321
+ # 2. Blocks
322
+ if self.caption_projection is not None:
323
+ batch_size = hidden_states.shape[0]
324
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
325
+ encoder_hidden_states = encoder_hidden_states.view(
326
+ batch_size, -1, hidden_states.shape[-1]
327
+ )
328
+
329
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
330
+ for block in self.transformer_blocks:
331
+ if self.training and self.gradient_checkpointing:
332
+
333
+ def create_custom_forward(module, return_dict=None):
334
+ def custom_forward(*inputs):
335
+ if return_dict is not None:
336
+ return module(*inputs, return_dict=return_dict)
337
+ else:
338
+ return module(*inputs)
339
+
340
+ return custom_forward
341
+
342
+ ckpt_kwargs: Dict[str, Any] = (
343
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ )
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(block),
347
+ hidden_states,
348
+ attention_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ timestep,
352
+ cross_attention_kwargs,
353
+ class_labels,
354
+ **ckpt_kwargs,
355
+ )
356
+ else:
357
+ hidden_states = block(
358
+ hidden_states,
359
+ attention_mask=attention_mask,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ encoder_attention_mask=encoder_attention_mask,
362
+ timestep=timestep,
363
+ cross_attention_kwargs=cross_attention_kwargs,
364
+ class_labels=class_labels,
365
+ )
366
+
367
+ # 3. Output
368
+ if self.is_input_continuous:
369
+ if not self.use_linear_projection:
370
+ hidden_states = (
371
+ hidden_states.reshape(batch, height, width, inner_dim)
372
+ .permute(0, 3, 1, 2)
373
+ .contiguous()
374
+ )
375
+ hidden_states = (
376
+ self.proj_out(hidden_states, scale=lora_scale)
377
+ if not USE_PEFT_BACKEND
378
+ else self.proj_out(hidden_states)
379
+ )
380
+ else:
381
+ hidden_states = (
382
+ self.proj_out(hidden_states, scale=lora_scale)
383
+ if not USE_PEFT_BACKEND
384
+ else self.proj_out(hidden_states)
385
+ )
386
+ hidden_states = (
387
+ hidden_states.reshape(batch, height, width, inner_dim)
388
+ .permute(0, 3, 1, 2)
389
+ .contiguous()
390
+ )
391
+
392
+ output = hidden_states + residual
393
+ if not return_dict:
394
+ return (output, ref_feature)
395
+
396
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
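
Editor's note: a sketch of what distinguishes this Transformer2DModel from the upstream diffusers class: its forward also returns `ref_feature`, the hidden states captured just before the transformer blocks, which the reference-attention hooks consume. Assumes the diffusers version pinned in requirements.txt and the repo root on the path; the tiny sizes are illustrative.

import torch
from src.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,  # inner_dim = 2 * 8 = 16
    in_channels=16,
    norm_num_groups=8,
)

x = torch.randn(1, 16, 8, 8)  # (batch, channels, height, width)
with torch.no_grad():
    out = model(x)

print(out.sample.shape)       # torch.Size([1, 16, 8, 8])
print(out.ref_feature.shape)  # torch.Size([1, 8, 8, 16]) -- (batch, h, w, inner_dim)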
src/models/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(inner_dim, in_channels)  # project from inner_dim back to in_channels
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * weight, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, weight, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, weight, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
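
Editor's note: a sketch of the reshaping Transformer3DModel performs around its temporal blocks (not part of the commit). The frame axis is folded into the batch axis so each block sees 2D feature maps, and the conditioning embedding is repeated once per frame to match. Pure tensor demo with illustrative sizes; no model weights involved.

import torch
from einops import rearrange, repeat

b, c, f, h, w = 1, 16, 8, 32, 32
hidden_states = torch.randn(b, c, f, h, w)
encoder_hidden_states = torch.randn(b, 77, 768)  # e.g. CLIP embeddings

hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)
print(hidden_states.shape)          # torch.Size([8, 16, 32, 32])
print(encoder_hidden_states.shape)  # torch.Size([8, 77, 768])

# After the blocks, the output is folded back into a video tensor.
output = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=f)
print(output.shape)                 # torch.Size([1, 16, 8, 32, 32])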
src/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warning(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warning(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warning(
300
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
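Note that, unlike the stock diffusers mid block, the forward above unpacks the transformer call as `hidden_states, ref_feature = attn(...)`: this repo's `Transformer2DModel` also returns reference features, which the mid block simply discards. A minimal sketch of exercising the block on dummy tensors follows; the channel, head, and token counts are illustrative assumptions, and it presumes the diffusers version pinned in requirements.txt (whose `ResnetBlock2D` accepts the `scale` argument used above).

import torch

mid = UNetMidBlock2DCrossAttn(
    in_channels=1280,                 # must be divisible by resnet_groups (default 32)
    temb_channels=1280,
    num_attention_heads=8,
    cross_attention_dim=768,          # e.g. CLIP text-embedding width
)
sample = torch.randn(1, 1280, 8, 8)   # (batch, channels, height, width)
temb = torch.randn(1, 1280)           # time embedding consumed by every ResnetBlock2D
context = torch.randn(1, 77, 768)     # encoder_hidden_states for cross-attention
out = mid(sample, temb=temb, encoder_hidden_states=context)
assert out.shape == sample.shape      # the mid block preserves spatial resolution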
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
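`CrossAttnDownBlock2D.forward` returns both the final hidden states and `output_states`, the per-layer outputs that the up blocks later consume as skip connections; `additional_residuals` (ControlNet-style) are added only to the output of the last resnet/attention pair. A hedged sketch with assumed sizes:

import torch

down = CrossAttnDownBlock2D(
    in_channels=320, out_channels=320, temb_channels=1280,
    num_layers=2, num_attention_heads=8, cross_attention_dim=768,
)
x = torch.randn(1, 320, 64, 64)
temb = torch.randn(1, 1280)
ctx = torch.randn(1, 77, 768)
x, skips = down(x, temb=temb, encoder_hidden_states=ctx)
# two resnet/attention outputs at 64x64 plus the downsampled output at 32x32
print([s.shape[-1] for s in skips])   # [64, 64, 32]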
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
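Every block in this file repeats the same gradient-checkpointing idiom: wrap the module in a closure so its positional arguments can be replayed during recomputation, and pass `use_reentrant=False` only when the installed torch is at least 1.11. A reduced sketch of that pattern (the helper name is ours, not part of this commit):

import torch
from diffusers.utils import is_torch_version

def run_checkpointed(module, *args):
    # recomputes `module` in the backward pass instead of storing its activations
    def custom_forward(*inputs):
        return module(*inputs)

    if is_torch_version(">=", "1.11.0"):
        return torch.utils.checkpoint.checkpoint(custom_forward, *args, use_reentrant=False)
    return torch.utils.checkpoint.checkpoint(custom_forward, *args)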
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
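On the way up, each resnet consumes a skip tensor popped from the end of `res_hidden_states_tuple`, so it is built for `resnet_in_channels + res_skip_channels` input channels. A small runnable check of that arithmetic for one SD-1.x-style stage (the concrete channel numbers are illustrative, not taken from this commit):

# mirrors the channel bookkeeping in UpBlock2D / CrossAttnUpBlock2D.__init__
in_channels, out_channels, prev_output_channel, num_layers = 320, 640, 1280, 3
for i in range(num_layers):
    res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
    resnet_in_channels = prev_output_channel if i == 0 else out_channels
    print(f"layer {i}: ResnetBlock2D({resnet_in_channels + res_skip_channels} -> {out_channels})")
# layer 0: ResnetBlock2D(1920 -> 640)
# layer 1: ResnetBlock2D(1280 -> 640)
# layer 2: ResnetBlock2D(960 -> 640)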
src/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1308 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ PositionNet,
24
+ TextImageProjection,
25
+ TextImageTimeEmbedding,
26
+ TextTimeEmbedding,
27
+ TimestepEmbedding,
28
+ Timesteps,
29
+ )
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from diffusers.utils import (
32
+ USE_PEFT_BACKEND,
33
+ BaseOutput,
34
+ deprecate,
35
+ logging,
36
+ scale_lora_layers,
37
+ unscale_lora_layers,
38
+ )
39
+
40
+ from .unet_2d_blocks import (
41
+ UNetMidBlock2D,
42
+ UNetMidBlock2DCrossAttn,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ @dataclass
51
+ class UNet2DConditionOutput(BaseOutput):
52
+ """
53
+ The output of [`UNet2DConditionModel`].
54
+
55
+ Args:
56
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
57
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
58
+ """
59
+
60
+ sample: torch.FloatTensor = None
61
+ ref_features: Tuple[torch.FloatTensor] = None
62
+
63
+
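Compared with the upstream diffusers dataclass, this output carries an extra `ref_features` field that the docstring above does not mention. A hedged access sketch, assuming the `forward` defined later in this file populates both fields when `return_dict=True` (`unet`, `sample`, `timestep`, and `text_embeds` are placeholders):

out = unet(sample, timestep, encoder_hidden_states=text_embeds)
noise_pred = out.sample        # (batch, out_channels, height, width), as documented above
ref_feats = out.ref_features   # repo-specific: reference features gathered from the attention blocks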
64
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
65
+ r"""
66
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
67
+ shaped output.
68
+
69
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
70
+ for all models (such as downloading or saving).
71
+
72
+ Parameters:
73
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
74
+ Height and width of input/output sample.
75
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
76
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
77
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
78
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
79
+ Whether to flip the sin to cos in the time embedding.
80
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
81
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
82
+ The tuple of downsample blocks to use.
83
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
84
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
85
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
86
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
97
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
98
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
99
+ If `None`, normalization and activation layers are skipped in post-processing.
100
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
101
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
102
+ The dimension of the cross attention features.
103
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
104
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
105
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
106
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
107
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
108
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
109
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ encoder_hid_dim (`int`, *optional*, defaults to None):
113
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
114
+ dimension to `cross_attention_dim`.
115
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
116
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
117
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
118
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
119
+ num_attention_heads (`int`, *optional*):
120
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
121
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
122
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
123
+ class_embed_type (`str`, *optional*, defaults to `None`):
124
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
125
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
126
+ addition_embed_type (`str`, *optional*, defaults to `None`):
127
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
128
+ "text". "text" will use the `TextTimeEmbedding` layer.
129
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
130
+ Dimension for the timestep embeddings.
131
+ num_class_embeds (`int`, *optional*, defaults to `None`):
132
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
133
+ class conditioning with `class_embed_type` equal to `None`.
134
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
135
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
136
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
137
+ An optional override for the dimension of the projected time embedding.
138
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
139
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
140
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
141
+ timestep_post_act (`str`, *optional*, defaults to `None`):
142
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
143
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
144
+ The dimension of `cond_proj` layer in the timestep embedding.
145
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
146
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
147
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
148
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
149
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
150
+ embeddings with the class embeddings.
151
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
152
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
153
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
154
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
155
+ otherwise.
156
+ """
157
+
158
+ _supports_gradient_checkpointing = True
159
+
160
+ @register_to_config
161
+ def __init__(
162
+ self,
163
+ sample_size: Optional[int] = None,
164
+ in_channels: int = 4,
165
+ out_channels: int = 4,
166
+ center_input_sample: bool = False,
167
+ flip_sin_to_cos: bool = True,
168
+ freq_shift: int = 0,
169
+ down_block_types: Tuple[str] = (
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "CrossAttnDownBlock2D",
173
+ "DownBlock2D",
174
+ ),
175
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
176
+ up_block_types: Tuple[str] = (
177
+ "UpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ "CrossAttnUpBlock2D",
181
+ ),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: int = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads=64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ if len(down_block_types) != len(up_block_types):
241
+ raise ValueError(
242
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
243
+ )
244
+
245
+ if len(block_out_channels) != len(down_block_types):
246
+ raise ValueError(
247
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
248
+ )
249
+
250
+ if not isinstance(only_cross_attention, bool) and len(
251
+ only_cross_attention
252
+ ) != len(down_block_types):
253
+ raise ValueError(
254
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
255
+ )
256
+
257
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
258
+ down_block_types
259
+ ):
260
+ raise ValueError(
261
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
262
+ )
263
+
264
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
265
+ down_block_types
266
+ ):
267
+ raise ValueError(
268
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
269
+ )
270
+
271
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
272
+ down_block_types
273
+ ):
274
+ raise ValueError(
275
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
276
+ )
277
+
278
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
279
+ down_block_types
280
+ ):
281
+ raise ValueError(
282
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
283
+ )
284
+ if (
285
+ isinstance(transformer_layers_per_block, list)
286
+ and reverse_transformer_layers_per_block is None
287
+ ):
288
+ for layer_number_per_block in transformer_layers_per_block:
289
+ if isinstance(layer_number_per_block, list):
290
+ raise ValueError(
291
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
292
+ )
293
+
294
+ # input
295
+ conv_in_padding = (conv_in_kernel - 1) // 2
296
+ self.conv_in = nn.Conv2d(
297
+ in_channels,
298
+ block_out_channels[0],
299
+ kernel_size=conv_in_kernel,
300
+ padding=conv_in_padding,
301
+ )
302
+
303
+ # time
304
+ if time_embedding_type == "fourier":
305
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
306
+ if time_embed_dim % 2 != 0:
307
+ raise ValueError(
308
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
309
+ )
310
+ self.time_proj = GaussianFourierProjection(
311
+ time_embed_dim // 2,
312
+ set_W_to_weight=False,
313
+ log=False,
314
+ flip_sin_to_cos=flip_sin_to_cos,
315
+ )
316
+ timestep_input_dim = time_embed_dim
317
+ elif time_embedding_type == "positional":
318
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
319
+
320
+ self.time_proj = Timesteps(
321
+ block_out_channels[0], flip_sin_to_cos, freq_shift
322
+ )
323
+ timestep_input_dim = block_out_channels[0]
324
+ else:
325
+ raise ValueError(
326
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
327
+ )
328
+
329
+ self.time_embedding = TimestepEmbedding(
330
+ timestep_input_dim,
331
+ time_embed_dim,
332
+ act_fn=act_fn,
333
+ post_act_fn=timestep_post_act,
334
+ cond_proj_dim=time_cond_proj_dim,
335
+ )
336
+
337
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
338
+ encoder_hid_dim_type = "text_proj"
339
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
340
+ logger.info(
341
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
342
+ )
343
+
344
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
345
+ raise ValueError(
346
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
347
+ )
348
+
349
+ if encoder_hid_dim_type == "text_proj":
350
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
351
+ elif encoder_hid_dim_type == "text_image_proj":
352
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
353
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
354
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
355
+ self.encoder_hid_proj = TextImageProjection(
356
+ text_embed_dim=encoder_hid_dim,
357
+ image_embed_dim=cross_attention_dim,
358
+ cross_attention_dim=cross_attention_dim,
359
+ )
360
+ elif encoder_hid_dim_type == "image_proj":
361
+ # Kandinsky 2.2
362
+ self.encoder_hid_proj = ImageProjection(
363
+ image_embed_dim=encoder_hid_dim,
364
+ cross_attention_dim=cross_attention_dim,
365
+ )
366
+ elif encoder_hid_dim_type is not None:
367
+ raise ValueError(
368
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
369
+ )
370
+ else:
371
+ self.encoder_hid_proj = None
372
+
373
+ # class embedding
374
+ if class_embed_type is None and num_class_embeds is not None:
375
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
376
+ elif class_embed_type == "timestep":
377
+ self.class_embedding = TimestepEmbedding(
378
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
379
+ )
380
+ elif class_embed_type == "identity":
381
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
382
+ elif class_embed_type == "projection":
383
+ if projection_class_embeddings_input_dim is None:
384
+ raise ValueError(
385
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
386
+ )
387
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
388
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
389
+ # 2. it projects from an arbitrary input dimension.
390
+ #
391
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
392
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
393
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
394
+ self.class_embedding = TimestepEmbedding(
395
+ projection_class_embeddings_input_dim, time_embed_dim
396
+ )
397
+ elif class_embed_type == "simple_projection":
398
+ if projection_class_embeddings_input_dim is None:
399
+ raise ValueError(
400
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
401
+ )
402
+ self.class_embedding = nn.Linear(
403
+ projection_class_embeddings_input_dim, time_embed_dim
404
+ )
405
+ else:
406
+ self.class_embedding = None
407
+
408
+ if addition_embed_type == "text":
409
+ if encoder_hid_dim is not None:
410
+ text_time_embedding_from_dim = encoder_hid_dim
411
+ else:
412
+ text_time_embedding_from_dim = cross_attention_dim
413
+
414
+ self.add_embedding = TextTimeEmbedding(
415
+ text_time_embedding_from_dim,
416
+ time_embed_dim,
417
+ num_heads=addition_embed_type_num_heads,
418
+ )
419
+ elif addition_embed_type == "text_image":
420
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
421
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
422
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
423
+ self.add_embedding = TextImageTimeEmbedding(
424
+ text_embed_dim=cross_attention_dim,
425
+ image_embed_dim=cross_attention_dim,
426
+ time_embed_dim=time_embed_dim,
427
+ )
428
+ elif addition_embed_type == "text_time":
429
+ self.add_time_proj = Timesteps(
430
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
431
+ )
432
+ self.add_embedding = TimestepEmbedding(
433
+ projection_class_embeddings_input_dim, time_embed_dim
434
+ )
435
+ elif addition_embed_type == "image":
436
+ # Kandinsky 2.2
437
+ self.add_embedding = ImageTimeEmbedding(
438
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
439
+ )
440
+ elif addition_embed_type == "image_hint":
441
+ # Kandinsky 2.2 ControlNet
442
+ self.add_embedding = ImageHintTimeEmbedding(
443
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
444
+ )
445
+ elif addition_embed_type is not None:
446
+ raise ValueError(
447
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
448
+ )
449
+
450
+ if time_embedding_act_fn is None:
451
+ self.time_embed_act = None
452
+ else:
453
+ self.time_embed_act = get_activation(time_embedding_act_fn)
454
+
455
+ self.down_blocks = nn.ModuleList([])
456
+ self.up_blocks = nn.ModuleList([])
457
+
458
+ if isinstance(only_cross_attention, bool):
459
+ if mid_block_only_cross_attention is None:
460
+ mid_block_only_cross_attention = only_cross_attention
461
+
462
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
463
+
464
+ if mid_block_only_cross_attention is None:
465
+ mid_block_only_cross_attention = False
466
+
467
+ if isinstance(num_attention_heads, int):
468
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
469
+
470
+ if isinstance(attention_head_dim, int):
471
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
472
+
473
+ if isinstance(cross_attention_dim, int):
474
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
475
+
476
+ if isinstance(layers_per_block, int):
477
+ layers_per_block = [layers_per_block] * len(down_block_types)
478
+
479
+ if isinstance(transformer_layers_per_block, int):
480
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
481
+ down_block_types
482
+ )
483
+
484
+ if class_embeddings_concat:
485
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
486
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
487
+ # regular time embeddings
488
+ blocks_time_embed_dim = time_embed_dim * 2
489
+ else:
490
+ blocks_time_embed_dim = time_embed_dim
491
+
492
+ # down
493
+ output_channel = block_out_channels[0]
494
+ for i, down_block_type in enumerate(down_block_types):
495
+ input_channel = output_channel
496
+ output_channel = block_out_channels[i]
497
+ is_final_block = i == len(block_out_channels) - 1
498
+
499
+ down_block = get_down_block(
500
+ down_block_type,
501
+ num_layers=layers_per_block[i],
502
+ transformer_layers_per_block=transformer_layers_per_block[i],
503
+ in_channels=input_channel,
504
+ out_channels=output_channel,
505
+ temb_channels=blocks_time_embed_dim,
506
+ add_downsample=not is_final_block,
507
+ resnet_eps=norm_eps,
508
+ resnet_act_fn=act_fn,
509
+ resnet_groups=norm_num_groups,
510
+ cross_attention_dim=cross_attention_dim[i],
511
+ num_attention_heads=num_attention_heads[i],
512
+ downsample_padding=downsample_padding,
513
+ dual_cross_attention=dual_cross_attention,
514
+ use_linear_projection=use_linear_projection,
515
+ only_cross_attention=only_cross_attention[i],
516
+ upcast_attention=upcast_attention,
517
+ resnet_time_scale_shift=resnet_time_scale_shift,
518
+ attention_type=attention_type,
519
+ resnet_skip_time_act=resnet_skip_time_act,
520
+ resnet_out_scale_factor=resnet_out_scale_factor,
521
+ cross_attention_norm=cross_attention_norm,
522
+ attention_head_dim=attention_head_dim[i]
523
+ if attention_head_dim[i] is not None
524
+ else output_channel,
525
+ dropout=dropout,
526
+ )
527
+ self.down_blocks.append(down_block)
528
+
529
+ # mid
530
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
531
+ self.mid_block = UNetMidBlock2DCrossAttn(
532
+ transformer_layers_per_block=transformer_layers_per_block[-1],
533
+ in_channels=block_out_channels[-1],
534
+ temb_channels=blocks_time_embed_dim,
535
+ dropout=dropout,
536
+ resnet_eps=norm_eps,
537
+ resnet_act_fn=act_fn,
538
+ output_scale_factor=mid_block_scale_factor,
539
+ resnet_time_scale_shift=resnet_time_scale_shift,
540
+ cross_attention_dim=cross_attention_dim[-1],
541
+ num_attention_heads=num_attention_heads[-1],
542
+ resnet_groups=norm_num_groups,
543
+ dual_cross_attention=dual_cross_attention,
544
+ use_linear_projection=use_linear_projection,
545
+ upcast_attention=upcast_attention,
546
+ attention_type=attention_type,
547
+ )
548
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
549
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
550
+ elif mid_block_type == "UNetMidBlock2D":
551
+ self.mid_block = UNetMidBlock2D(
552
+ in_channels=block_out_channels[-1],
553
+ temb_channels=blocks_time_embed_dim,
554
+ dropout=dropout,
555
+ num_layers=0,
556
+ resnet_eps=norm_eps,
557
+ resnet_act_fn=act_fn,
558
+ output_scale_factor=mid_block_scale_factor,
559
+ resnet_groups=norm_num_groups,
560
+ resnet_time_scale_shift=resnet_time_scale_shift,
561
+ add_attention=False,
562
+ )
563
+ elif mid_block_type is None:
564
+ self.mid_block = None
565
+ else:
566
+ raise ValueError(f"unknown mid_block_type: {mid_block_type}")
567
+
568
+ # count how many layers upsample the images
569
+ self.num_upsamplers = 0
570
+
571
+ # up
572
+ reversed_block_out_channels = list(reversed(block_out_channels))
573
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
574
+ reversed_layers_per_block = list(reversed(layers_per_block))
575
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
576
+ reversed_transformer_layers_per_block = (
577
+ list(reversed(transformer_layers_per_block))
578
+ if reverse_transformer_layers_per_block is None
579
+ else reverse_transformer_layers_per_block
580
+ )
581
+ only_cross_attention = list(reversed(only_cross_attention))
582
+
583
+ output_channel = reversed_block_out_channels[0]
584
+ for i, up_block_type in enumerate(up_block_types):
585
+ is_final_block = i == len(block_out_channels) - 1
586
+
587
+ prev_output_channel = output_channel
588
+ output_channel = reversed_block_out_channels[i]
589
+ input_channel = reversed_block_out_channels[
590
+ min(i + 1, len(block_out_channels) - 1)
591
+ ]
592
+
593
+ # add upsample block for all BUT final layer
594
+ if not is_final_block:
595
+ add_upsample = True
596
+ self.num_upsamplers += 1
597
+ else:
598
+ add_upsample = False
599
+
600
+ up_block = get_up_block(
601
+ up_block_type,
602
+ num_layers=reversed_layers_per_block[i] + 1,
603
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
604
+ in_channels=input_channel,
605
+ out_channels=output_channel,
606
+ prev_output_channel=prev_output_channel,
607
+ temb_channels=blocks_time_embed_dim,
608
+ add_upsample=add_upsample,
609
+ resnet_eps=norm_eps,
610
+ resnet_act_fn=act_fn,
611
+ resolution_idx=i,
612
+ resnet_groups=norm_num_groups,
613
+ cross_attention_dim=reversed_cross_attention_dim[i],
614
+ num_attention_heads=reversed_num_attention_heads[i],
615
+ dual_cross_attention=dual_cross_attention,
616
+ use_linear_projection=use_linear_projection,
617
+ only_cross_attention=only_cross_attention[i],
618
+ upcast_attention=upcast_attention,
619
+ resnet_time_scale_shift=resnet_time_scale_shift,
620
+ attention_type=attention_type,
621
+ resnet_skip_time_act=resnet_skip_time_act,
622
+ resnet_out_scale_factor=resnet_out_scale_factor,
623
+ cross_attention_norm=cross_attention_norm,
624
+ attention_head_dim=attention_head_dim[i]
625
+ if attention_head_dim[i] is not None
626
+ else output_channel,
627
+ dropout=dropout,
628
+ )
629
+ self.up_blocks.append(up_block)
630
+ prev_output_channel = output_channel
631
+
632
+ # out
633
+ if norm_num_groups is not None:
634
+ self.conv_norm_out = nn.GroupNorm(
635
+ num_channels=block_out_channels[0],
636
+ num_groups=norm_num_groups,
637
+ eps=norm_eps,
638
+ )
639
+
640
+ self.conv_act = get_activation(act_fn)
641
+
642
+ else:
643
+ self.conv_norm_out = None
644
+ self.conv_act = None
646
+
647
+ conv_out_padding = (conv_out_kernel - 1) // 2
648
+ # self.conv_out = nn.Conv2d(
649
+ # block_out_channels[0],
650
+ # out_channels,
651
+ # kernel_size=conv_out_kernel,
652
+ # padding=conv_out_padding,
653
+ # )
654
+
655
+ if attention_type in ["gated", "gated-text-image"]:
656
+ positive_len = 768
657
+ if isinstance(cross_attention_dim, int):
658
+ positive_len = cross_attention_dim
659
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
660
+ cross_attention_dim, list
661
+ ):
662
+ positive_len = cross_attention_dim[0]
663
+
664
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
665
+ self.position_net = PositionNet(
666
+ positive_len=positive_len,
667
+ out_dim=cross_attention_dim,
668
+ feature_type=feature_type,
669
+ )
670
+
671
+ @property
672
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
673
+ r"""
674
+ Returns:
675
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
676
+ indexed by their weight names.
677
+ """
678
+ # set recursively
679
+ processors = {}
680
+
681
+ def fn_recursive_add_processors(
682
+ name: str,
683
+ module: torch.nn.Module,
684
+ processors: Dict[str, AttentionProcessor],
685
+ ):
686
+ if hasattr(module, "get_processor"):
687
+ processors[f"{name}.processor"] = module.get_processor(
688
+ return_deprecated_lora=True
689
+ )
690
+
691
+ for sub_name, child in module.named_children():
692
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
693
+
694
+ return processors
695
+
696
+ for name, module in self.named_children():
697
+ fn_recursive_add_processors(name, module, processors)
698
+
699
+ return processors
700
+
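The property above and `set_attn_processor` below are symmetric: the getter walks the module tree and returns a dict keyed by weight path (`"...processor"`), while the setter accepts either a single processor broadcast to every attention layer or a dict with exactly those keys. A hedged usage sketch (`unet` is an assumed instance of this class):

from diffusers.models.attention_processor import AttnProcessor

procs = unet.attn_processors               # {"down_blocks.0....processor": <processor>, ...}
unet.set_attn_processor(AttnProcessor())   # one processor applied to all attention layers
unet.set_attn_processor(dict(procs))       # or a per-layer dict using the same keys (the setter pops from it)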
701
+ def set_attn_processor(
702
+ self,
703
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
704
+ _remove_lora=False,
705
+ ):
706
+ r"""
707
+ Sets the attention processor to use to compute attention.
708
+
709
+ Parameters:
710
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
711
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
712
+ for **all** `Attention` layers.
713
+
714
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
715
+ processor. This is strongly recommended when setting trainable attention processors.
716
+
717
+ """
718
+ count = len(self.attn_processors.keys())
719
+
720
+ if isinstance(processor, dict) and len(processor) != count:
721
+ raise ValueError(
722
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
723
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
724
+ )
725
+
726
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
727
+ if hasattr(module, "set_processor"):
728
+ if not isinstance(processor, dict):
729
+ module.set_processor(processor, _remove_lora=_remove_lora)
730
+ else:
731
+ module.set_processor(
732
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
733
+ )
734
+
735
+ for sub_name, child in module.named_children():
736
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
737
+
738
+ for name, module in self.named_children():
739
+ fn_recursive_attn_processor(name, module, processor)
740
+
741
+ def set_default_attn_processor(self):
742
+ """
743
+ Disables custom attention processors and sets the default attention implementation.
744
+ """
745
+ if all(
746
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
747
+ for proc in self.attn_processors.values()
748
+ ):
749
+ processor = AttnAddedKVProcessor()
750
+ elif all(
751
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
752
+ for proc in self.attn_processors.values()
753
+ ):
754
+ processor = AttnProcessor()
755
+ else:
756
+ raise ValueError(
757
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
758
+ )
759
+
760
+ self.set_attn_processor(processor, _remove_lora=True)
761
+
762
+ def set_attention_slice(self, slice_size):
763
+ r"""
764
+ Enable sliced attention computation.
765
+
766
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
767
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
768
+
769
+ Args:
770
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
771
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
772
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
773
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
774
+ must be a multiple of `slice_size`.
775
+ """
776
+ sliceable_head_dims = []
777
+
778
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
779
+ if hasattr(module, "set_attention_slice"):
780
+ sliceable_head_dims.append(module.sliceable_head_dim)
781
+
782
+ for child in module.children():
783
+ fn_recursive_retrieve_sliceable_dims(child)
784
+
785
+ # retrieve number of attention layers
786
+ for module in self.children():
787
+ fn_recursive_retrieve_sliceable_dims(module)
788
+
789
+ num_sliceable_layers = len(sliceable_head_dims)
790
+
791
+ if slice_size == "auto":
792
+ # half the attention head size is usually a good trade-off between
793
+ # speed and memory
794
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
795
+ elif slice_size == "max":
796
+ # make smallest slice possible
797
+ slice_size = num_sliceable_layers * [1]
798
+
799
+ slice_size = (
800
+ num_sliceable_layers * [slice_size]
801
+ if not isinstance(slice_size, list)
802
+ else slice_size
803
+ )
804
+
805
+ if len(slice_size) != len(sliceable_head_dims):
806
+ raise ValueError(
807
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
808
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
809
+ )
810
+
811
+ for i in range(len(slice_size)):
812
+ size = slice_size[i]
813
+ dim = sliceable_head_dims[i]
814
+ if size is not None and size > dim:
815
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
816
+
817
+ # Recursively walk through all the children.
818
+ # Any children which exposes the set_attention_slice method
819
+ # gets the message
820
+ def fn_recursive_set_attention_slice(
821
+ module: torch.nn.Module, slice_size: List[int]
822
+ ):
823
+ if hasattr(module, "set_attention_slice"):
824
+ module.set_attention_slice(slice_size.pop())
825
+
826
+ for child in module.children():
827
+ fn_recursive_set_attention_slice(child, slice_size)
828
+
829
+ reversed_slice_size = list(reversed(slice_size))
830
+ for module in self.children():
831
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
832
+
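+ # Illustrative usage note (not part of the original file): `unet.set_attention_slice("auto")`
+ # halves each sliceable head dim, `"max"` runs one slice at a time, and a list of ints
+ # sets a per-layer slice size, matching the dispatch logic below.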
833
+ def _set_gradient_checkpointing(self, module, value=False):
834
+ if hasattr(module, "gradient_checkpointing"):
835
+ module.gradient_checkpointing = value
836
+
837
+ def enable_freeu(self, s1, s2, b1, b2):
838
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
839
+
840
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
841
+
842
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
843
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
844
+
845
+ Args:
846
+ s1 (`float`):
847
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ s2 (`float`):
850
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
851
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
852
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
853
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
854
+ """
855
+ for i, upsample_block in enumerate(self.up_blocks):
856
+ setattr(upsample_block, "s1", s1)
857
+ setattr(upsample_block, "s2", s2)
858
+ setattr(upsample_block, "b1", b1)
859
+ setattr(upsample_block, "b2", b2)
860
+
861
+ def disable_freeu(self):
862
+ """Disables the FreeU mechanism."""
863
+ freeu_keys = {"s1", "s2", "b1", "b2"}
864
+ for i, upsample_block in enumerate(self.up_blocks):
865
+ for k in freeu_keys:
866
+ if (
867
+ hasattr(upsample_block, k)
868
+ or getattr(upsample_block, k, None) is not None
869
+ ):
870
+ setattr(upsample_block, k, None)
871
+
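+ # Illustrative usage note (not part of the original file); well-performing FreeU values
+ # differ per base model, see the official FreeU repository, e.g.:
+ #   unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)   # example numbers only
+ #   unet.disable_freeu()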
872
+ def forward(
873
+ self,
874
+ sample: torch.FloatTensor,
875
+ timestep: Union[torch.Tensor, float, int],
876
+ encoder_hidden_states: torch.Tensor,
877
+ class_labels: Optional[torch.Tensor] = None,
878
+ timestep_cond: Optional[torch.Tensor] = None,
879
+ attention_mask: Optional[torch.Tensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
882
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
883
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
884
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
885
+ encoder_attention_mask: Optional[torch.Tensor] = None,
886
+ return_dict: bool = True,
887
+ ) -> Union[UNet2DConditionOutput, Tuple]:
888
+ r"""
889
+ The [`UNet2DConditionModel`] forward method.
890
+
891
+ Args:
892
+ sample (`torch.FloatTensor`):
893
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
894
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
895
+ encoder_hidden_states (`torch.FloatTensor`):
896
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
897
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
898
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
899
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
900
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
901
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
902
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
903
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
904
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
905
+ negative values to the attention scores corresponding to "discard" tokens.
906
+ cross_attention_kwargs (`dict`, *optional*):
907
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
908
+ `self.processor` in
909
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
910
+ added_cond_kwargs: (`dict`, *optional*):
911
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
912
+ are passed along to the UNet blocks.
913
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
914
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
915
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
916
+ A tensor that if specified is added to the residual of the middle unet block.
917
+ encoder_attention_mask (`torch.Tensor`):
918
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
919
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
920
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
921
+ return_dict (`bool`, *optional*, defaults to `True`):
922
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
923
+ tuple.
924
+ cross_attention_kwargs (`dict`, *optional*):
925
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
926
+ added_cond_kwargs: (`dict`, *optional*):
927
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
928
+ are passed along to the UNet blocks.
929
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
930
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
931
+ example from ControlNet side model(s)
932
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
933
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
934
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
935
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
936
+
937
+ Returns:
938
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
939
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
940
+ a `tuple` is returned where the first element is the sample tensor.
941
+ """
942
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
943
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
944
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
945
+ # on the fly if necessary.
946
+ default_overall_up_factor = 2**self.num_upsamplers
947
+
948
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
949
+ forward_upsample_size = False
950
+ upsample_size = None
951
+
952
+ for dim in sample.shape[-2:]:
953
+ if dim % default_overall_up_factor != 0:
954
+ # Forward upsample size to force interpolation output size.
955
+ forward_upsample_size = True
956
+ break
957
+
958
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
959
+ # expects mask of shape:
960
+ # [batch, key_tokens]
961
+ # adds singleton query_tokens dimension:
962
+ # [batch, 1, key_tokens]
963
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
964
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
965
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
966
+ if attention_mask is not None:
967
+ # assume that mask is expressed as:
968
+ # (1 = keep, 0 = discard)
969
+ # convert mask into a bias that can be added to attention scores:
970
+ # (keep = +0, discard = -10000.0)
971
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
972
+ attention_mask = attention_mask.unsqueeze(1)
973
+
974
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
975
+ if encoder_attention_mask is not None:
976
+ encoder_attention_mask = (
977
+ 1 - encoder_attention_mask.to(sample.dtype)
978
+ ) * -10000.0
979
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
980
+
981
+ # 0. center input if necessary
982
+ if self.config.center_input_sample:
983
+ sample = 2 * sample - 1.0
984
+
985
+ # 1. time
986
+ timesteps = timestep
987
+ if not torch.is_tensor(timesteps):
988
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
989
+ # This would be a good case for the `match` statement (Python 3.10+)
990
+ is_mps = sample.device.type == "mps"
991
+ if isinstance(timestep, float):
992
+ dtype = torch.float32 if is_mps else torch.float64
993
+ else:
994
+ dtype = torch.int32 if is_mps else torch.int64
995
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
996
+ elif len(timesteps.shape) == 0:
997
+ timesteps = timesteps[None].to(sample.device)
998
+
999
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1000
+ timesteps = timesteps.expand(sample.shape[0])
1001
+
1002
+ t_emb = self.time_proj(timesteps)
1003
+
1004
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1005
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1006
+ # there might be better ways to encapsulate this.
1007
+ t_emb = t_emb.to(dtype=sample.dtype)
1008
+
1009
+ emb = self.time_embedding(t_emb, timestep_cond)
1010
+ aug_emb = None
1011
+
1012
+ if self.class_embedding is not None:
1013
+ if class_labels is None:
1014
+ raise ValueError(
1015
+ "class_labels should be provided when num_class_embeds > 0"
1016
+ )
1017
+
1018
+ if self.config.class_embed_type == "timestep":
1019
+ class_labels = self.time_proj(class_labels)
1020
+
1021
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1022
+ # there might be better ways to encapsulate this.
1023
+ class_labels = class_labels.to(dtype=sample.dtype)
1024
+
1025
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1026
+
1027
+ if self.config.class_embeddings_concat:
1028
+ emb = torch.cat([emb, class_emb], dim=-1)
1029
+ else:
1030
+ emb = emb + class_emb
1031
+
1032
+ if self.config.addition_embed_type == "text":
1033
+ aug_emb = self.add_embedding(encoder_hidden_states)
1034
+ elif self.config.addition_embed_type == "text_image":
1035
+ # Kandinsky 2.1 - style
1036
+ if "image_embeds" not in added_cond_kwargs:
1037
+ raise ValueError(
1038
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1039
+ )
1040
+
1041
+ image_embs = added_cond_kwargs.get("image_embeds")
1042
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1043
+ aug_emb = self.add_embedding(text_embs, image_embs)
1044
+ elif self.config.addition_embed_type == "text_time":
1045
+ # SDXL - style
1046
+ if "text_embeds" not in added_cond_kwargs:
1047
+ raise ValueError(
1048
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1049
+ )
1050
+ text_embeds = added_cond_kwargs.get("text_embeds")
1051
+ if "time_ids" not in added_cond_kwargs:
1052
+ raise ValueError(
1053
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1054
+ )
1055
+ time_ids = added_cond_kwargs.get("time_ids")
1056
+ time_embeds = self.add_time_proj(time_ids.flatten())
1057
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1058
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1059
+ add_embeds = add_embeds.to(emb.dtype)
1060
+ aug_emb = self.add_embedding(add_embeds)
1061
+ elif self.config.addition_embed_type == "image":
1062
+ # Kandinsky 2.2 - style
1063
+ if "image_embeds" not in added_cond_kwargs:
1064
+ raise ValueError(
1065
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1066
+ )
1067
+ image_embs = added_cond_kwargs.get("image_embeds")
1068
+ aug_emb = self.add_embedding(image_embs)
1069
+ elif self.config.addition_embed_type == "image_hint":
1070
+ # Kandinsky 2.2 - style
1071
+ if (
1072
+ "image_embeds" not in added_cond_kwargs
1073
+ or "hint" not in added_cond_kwargs
1074
+ ):
1075
+ raise ValueError(
1076
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1077
+ )
1078
+ image_embs = added_cond_kwargs.get("image_embeds")
1079
+ hint = added_cond_kwargs.get("hint")
1080
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1081
+ sample = torch.cat([sample, hint], dim=1)
1082
+
1083
+ emb = emb + aug_emb if aug_emb is not None else emb
1084
+
1085
+ if self.time_embed_act is not None:
1086
+ emb = self.time_embed_act(emb)
1087
+
1088
+ if (
1089
+ self.encoder_hid_proj is not None
1090
+ and self.config.encoder_hid_dim_type == "text_proj"
1091
+ ):
1092
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1093
+ elif (
1094
+ self.encoder_hid_proj is not None
1095
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1096
+ ):
1097
+ # Kandinsky 2.1 - style
1098
+ if "image_embeds" not in added_cond_kwargs:
1099
+ raise ValueError(
1100
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1101
+ )
1102
+
1103
+ image_embeds = added_cond_kwargs.get("image_embeds")
1104
+ encoder_hidden_states = self.encoder_hid_proj(
1105
+ encoder_hidden_states, image_embeds
1106
+ )
1107
+ elif (
1108
+ self.encoder_hid_proj is not None
1109
+ and self.config.encoder_hid_dim_type == "image_proj"
1110
+ ):
1111
+ # Kandinsky 2.2 - style
1112
+ if "image_embeds" not in added_cond_kwargs:
1113
+ raise ValueError(
1114
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1115
+ )
1116
+ image_embeds = added_cond_kwargs.get("image_embeds")
1117
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1118
+ elif (
1119
+ self.encoder_hid_proj is not None
1120
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1121
+ ):
1122
+ if "image_embeds" not in added_cond_kwargs:
1123
+ raise ValueError(
1124
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1125
+ )
1126
+ image_embeds = added_cond_kwargs.get("image_embeds")
1127
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1128
+ encoder_hidden_states.dtype
1129
+ )
1130
+ encoder_hidden_states = torch.cat(
1131
+ [encoder_hidden_states, image_embeds], dim=1
1132
+ )
1133
+
1134
+ # 2. pre-process
1135
+ sample = self.conv_in(sample)
1136
+
1137
+ # 2.5 GLIGEN position net
1138
+ if (
1139
+ cross_attention_kwargs is not None
1140
+ and cross_attention_kwargs.get("gligen", None) is not None
1141
+ ):
1142
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1143
+ gligen_args = cross_attention_kwargs.pop("gligen")
1144
+ cross_attention_kwargs["gligen"] = {
1145
+ "objs": self.position_net(**gligen_args)
1146
+ }
1147
+
1148
+ # 3. down
1149
+ lora_scale = (
1150
+ cross_attention_kwargs.get("scale", 1.0)
1151
+ if cross_attention_kwargs is not None
1152
+ else 1.0
1153
+ )
1154
+ if USE_PEFT_BACKEND:
1155
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1156
+ scale_lora_layers(self, lora_scale)
1157
+
1158
+ is_controlnet = (
1159
+ mid_block_additional_residual is not None
1160
+ and down_block_additional_residuals is not None
1161
+ )
1162
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1163
+ is_adapter = down_intrablock_additional_residuals is not None
1164
+ # maintain backward compatibility for legacy usage, where
1165
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1166
+ # but can only use one or the other
1167
+ if (
1168
+ not is_adapter
1169
+ and mid_block_additional_residual is None
1170
+ and down_block_additional_residuals is not None
1171
+ ):
1172
+ deprecate(
1173
+ "T2I should not use down_block_additional_residuals",
1174
+ "1.3.0",
1175
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1176
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1177
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1178
+ standard_warn=False,
1179
+ )
1180
+ down_intrablock_additional_residuals = down_block_additional_residuals
1181
+ is_adapter = True
1182
+
1183
+ down_block_res_samples = (sample,)
1184
+ tot_reference_features = ()
1185
+ for downsample_block in self.down_blocks:
1186
+ if (
1187
+ hasattr(downsample_block, "has_cross_attention")
1188
+ and downsample_block.has_cross_attention
1189
+ ):
1190
+ # For t2i-adapter CrossAttnDownBlock2D
1191
+ additional_residuals = {}
1192
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1193
+ additional_residuals[
1194
+ "additional_residuals"
1195
+ ] = down_intrablock_additional_residuals.pop(0)
1196
+
1197
+ sample, res_samples = downsample_block(
1198
+ hidden_states=sample,
1199
+ temb=emb,
1200
+ encoder_hidden_states=encoder_hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ cross_attention_kwargs=cross_attention_kwargs,
1203
+ encoder_attention_mask=encoder_attention_mask,
1204
+ **additional_residuals,
1205
+ )
1206
+ else:
1207
+ sample, res_samples = downsample_block(
1208
+ hidden_states=sample, temb=emb, scale=lora_scale
1209
+ )
1210
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1211
+ sample += down_intrablock_additional_residuals.pop(0)
1212
+
1213
+ down_block_res_samples += res_samples
1214
+
1215
+ if is_controlnet:
1216
+ new_down_block_res_samples = ()
1217
+
1218
+ for down_block_res_sample, down_block_additional_residual in zip(
1219
+ down_block_res_samples, down_block_additional_residuals
1220
+ ):
1221
+ down_block_res_sample = (
1222
+ down_block_res_sample + down_block_additional_residual
1223
+ )
1224
+ new_down_block_res_samples = new_down_block_res_samples + (
1225
+ down_block_res_sample,
1226
+ )
1227
+
1228
+ down_block_res_samples = new_down_block_res_samples
1229
+
1230
+ # 4. mid
1231
+ if self.mid_block is not None:
1232
+ if (
1233
+ hasattr(self.mid_block, "has_cross_attention")
1234
+ and self.mid_block.has_cross_attention
1235
+ ):
1236
+ sample = self.mid_block(
1237
+ sample,
1238
+ emb,
1239
+ encoder_hidden_states=encoder_hidden_states,
1240
+ attention_mask=attention_mask,
1241
+ cross_attention_kwargs=cross_attention_kwargs,
1242
+ encoder_attention_mask=encoder_attention_mask,
1243
+ )
1244
+ else:
1245
+ sample = self.mid_block(sample, emb)
1246
+
1247
+ # To support T2I-Adapter-XL
1248
+ if (
1249
+ is_adapter
1250
+ and len(down_intrablock_additional_residuals) > 0
1251
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1252
+ ):
1253
+ sample += down_intrablock_additional_residuals.pop(0)
1254
+
1255
+ if is_controlnet:
1256
+ sample = sample + mid_block_additional_residual
1257
+
1258
+ # 5. up
1259
+ for i, upsample_block in enumerate(self.up_blocks):
1260
+ is_final_block = i == len(self.up_blocks) - 1
1261
+
1262
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1263
+ down_block_res_samples = down_block_res_samples[
1264
+ : -len(upsample_block.resnets)
1265
+ ]
1266
+
1267
+ # if we have not reached the final block and need to forward the
1268
+ # upsample size, we do it here
1269
+ if not is_final_block and forward_upsample_size:
1270
+ upsample_size = down_block_res_samples[-1].shape[2:]
1271
+
1272
+ if (
1273
+ hasattr(upsample_block, "has_cross_attention")
1274
+ and upsample_block.has_cross_attention
1275
+ ):
1276
+ sample = upsample_block(
1277
+ hidden_states=sample,
1278
+ temb=emb,
1279
+ res_hidden_states_tuple=res_samples,
1280
+ encoder_hidden_states=encoder_hidden_states,
1281
+ cross_attention_kwargs=cross_attention_kwargs,
1282
+ upsample_size=upsample_size,
1283
+ attention_mask=attention_mask,
1284
+ encoder_attention_mask=encoder_attention_mask,
1285
+ )
1286
+ else:
1287
+ sample = upsample_block(
1288
+ hidden_states=sample,
1289
+ temb=emb,
1290
+ res_hidden_states_tuple=res_samples,
1291
+ upsample_size=upsample_size,
1292
+ scale=lora_scale,
1293
+ )
1294
+
1295
+ # 6. post-process
1296
+ # if self.conv_norm_out:
1297
+ # sample = self.conv_norm_out(sample)
1298
+ # sample = self.conv_act(sample)
1299
+ # sample = self.conv_out(sample)
1300
+
1301
+ if USE_PEFT_BACKEND:
1302
+ # remove `lora_scale` from each PEFT layer
1303
+ unscale_lora_layers(self, lora_scale)
1304
+
1305
+ if not return_dict:
1306
+ return (sample,)
1307
+
1308
+ return UNet2DConditionOutput(sample=sample)
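The forward pass above folds a binary attention mask into an additive bias (keep = +0, discard = -10000) with a singleton query dimension so it can broadcast over the attention scores. A minimal standalone sketch of that conversion, with purely illustrative shapes:

import torch

attention_mask = torch.tensor([[1., 1., 0., 0.],
                               [1., 1., 1., 0.]])   # [batch, key_tokens], 1 = keep, 0 = discard
bias = (1 - attention_mask) * -10000.0              # keep -> +0, discard -> -10000
bias = bias.unsqueeze(1)                            # [batch, 1, key_tokens]

scores = torch.randn(2, 3, 4)                       # [batch, query_tokens, key_tokens]
probs = (scores + bias).softmax(dim=-1)             # masked keys receive ~0 attention weight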
src/models/unet_3d.py ADDED
@@ -0,0 +1,668 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
17
+ from safetensors.torch import load_file
18
+
19
+ from .resnet import InflatedConv3d, InflatedGroupNorm
20
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ @dataclass
26
+ class UNet3DConditionOutput(BaseOutput):
27
+ sample: torch.FloatTensor
28
+
29
+
30
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
31
+ _supports_gradient_checkpointing = True
32
+
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ sample_size: Optional[int] = None,
37
+ in_channels: int = 4,
38
+ out_channels: int = 4,
39
+ center_input_sample: bool = False,
40
+ flip_sin_to_cos: bool = True,
41
+ freq_shift: int = 0,
42
+ down_block_types: Tuple[str] = (
43
+ "CrossAttnDownBlock3D",
44
+ "CrossAttnDownBlock3D",
45
+ "CrossAttnDownBlock3D",
46
+ "DownBlock3D",
47
+ ),
48
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
49
+ up_block_types: Tuple[str] = (
50
+ "UpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ "CrossAttnUpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ ),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ use_inflated_groupnorm=False,
72
+ # Additional
73
+ use_motion_module=False,
74
+ motion_module_resolutions=(1, 2, 4, 8),
75
+ motion_module_mid_block=False,
76
+ motion_module_decoder_only=False,
77
+ motion_module_type=None,
78
+ motion_module_kwargs={},
79
+ unet_use_cross_frame_attention=None,
80
+ unet_use_temporal_attention=None,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.sample_size = sample_size
85
+ time_embed_dim = block_out_channels[0] * 4
86
+
87
+ # input
88
+ self.conv_in = InflatedConv3d(
89
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
90
+ )
91
+
92
+ # time
93
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
94
+ timestep_input_dim = block_out_channels[0]
95
+
96
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
97
+
98
+ # class embedding
99
+ if class_embed_type is None and num_class_embeds is not None:
100
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
101
+ elif class_embed_type == "timestep":
102
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+ elif class_embed_type == "identity":
104
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
105
+ else:
106
+ self.class_embedding = None
107
+
108
+ self.down_blocks = nn.ModuleList([])
109
+ self.mid_block = None
110
+ self.up_blocks = nn.ModuleList([])
111
+
112
+ if isinstance(only_cross_attention, bool):
113
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
114
+
115
+ if isinstance(attention_head_dim, int):
116
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
117
+
118
+ # down
119
+ output_channel = block_out_channels[0]
120
+ for i, down_block_type in enumerate(down_block_types):
121
+ res = 2**i
122
+ input_channel = output_channel
123
+ output_channel = block_out_channels[i]
124
+ is_final_block = i == len(block_out_channels) - 1
125
+
126
+ down_block = get_down_block(
127
+ down_block_type,
128
+ num_layers=layers_per_block,
129
+ in_channels=input_channel,
130
+ out_channels=output_channel,
131
+ temb_channels=time_embed_dim,
132
+ add_downsample=not is_final_block,
133
+ resnet_eps=norm_eps,
134
+ resnet_act_fn=act_fn,
135
+ resnet_groups=norm_num_groups,
136
+ cross_attention_dim=cross_attention_dim,
137
+ attn_num_head_channels=attention_head_dim[i],
138
+ downsample_padding=downsample_padding,
139
+ dual_cross_attention=dual_cross_attention,
140
+ use_linear_projection=use_linear_projection,
141
+ only_cross_attention=only_cross_attention[i],
142
+ upcast_attention=upcast_attention,
143
+ resnet_time_scale_shift=resnet_time_scale_shift,
144
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
145
+ unet_use_temporal_attention=unet_use_temporal_attention,
146
+ use_inflated_groupnorm=use_inflated_groupnorm,
147
+ use_motion_module=use_motion_module
148
+ and (res in motion_module_resolutions)
149
+ and (not motion_module_decoder_only),
150
+ motion_module_type=motion_module_type,
151
+ motion_module_kwargs=motion_module_kwargs,
152
+ )
153
+ self.down_blocks.append(down_block)
154
+
155
+ # mid
156
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
157
+ self.mid_block = UNetMidBlock3DCrossAttn(
158
+ in_channels=block_out_channels[-1],
159
+ temb_channels=time_embed_dim,
160
+ resnet_eps=norm_eps,
161
+ resnet_act_fn=act_fn,
162
+ output_scale_factor=mid_block_scale_factor,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ cross_attention_dim=cross_attention_dim,
165
+ attn_num_head_channels=attention_head_dim[-1],
166
+ resnet_groups=norm_num_groups,
167
+ dual_cross_attention=dual_cross_attention,
168
+ use_linear_projection=use_linear_projection,
169
+ upcast_attention=upcast_attention,
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+ use_inflated_groupnorm=use_inflated_groupnorm,
173
+ use_motion_module=use_motion_module and motion_module_mid_block,
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ else:
178
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
179
+
180
+ # count how many layers upsample the videos
181
+ self.num_upsamplers = 0
182
+
183
+ # up
184
+ reversed_block_out_channels = list(reversed(block_out_channels))
185
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
186
+ only_cross_attention = list(reversed(only_cross_attention))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ res = 2 ** (3 - i)
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[
195
+ min(i + 1, len(block_out_channels) - 1)
196
+ ]
197
+
198
+ # add upsample block for all BUT final layer
199
+ if not is_final_block:
200
+ add_upsample = True
201
+ self.num_upsamplers += 1
202
+ else:
203
+ add_upsample = False
204
+
205
+ up_block = get_up_block(
206
+ up_block_type,
207
+ num_layers=layers_per_block + 1,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ prev_output_channel=prev_output_channel,
211
+ temb_channels=time_embed_dim,
212
+ add_upsample=add_upsample,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ resnet_groups=norm_num_groups,
216
+ cross_attention_dim=cross_attention_dim,
217
+ attn_num_head_channels=reversed_attention_head_dim[i],
218
+ dual_cross_attention=dual_cross_attention,
219
+ use_linear_projection=use_linear_projection,
220
+ only_cross_attention=only_cross_attention[i],
221
+ upcast_attention=upcast_attention,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
224
+ unet_use_temporal_attention=unet_use_temporal_attention,
225
+ use_inflated_groupnorm=use_inflated_groupnorm,
226
+ use_motion_module=use_motion_module
227
+ and (res in motion_module_resolutions),
228
+ motion_module_type=motion_module_type,
229
+ motion_module_kwargs=motion_module_kwargs,
230
+ )
231
+ self.up_blocks.append(up_block)
232
+ prev_output_channel = output_channel
233
+
234
+ # out
235
+ if use_inflated_groupnorm:
236
+ self.conv_norm_out = InflatedGroupNorm(
237
+ num_channels=block_out_channels[0],
238
+ num_groups=norm_num_groups,
239
+ eps=norm_eps,
240
+ )
241
+ else:
242
+ self.conv_norm_out = nn.GroupNorm(
243
+ num_channels=block_out_channels[0],
244
+ num_groups=norm_num_groups,
245
+ eps=norm_eps,
246
+ )
247
+ self.conv_act = nn.SiLU()
248
+ self.conv_out = InflatedConv3d(
249
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
250
+ )
251
+
252
+ @property
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
254
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
255
+ r"""
256
+ Returns:
257
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
258
+ indexed by its weight name.
259
+ """
260
+ # set recursively
261
+ processors = {}
262
+
263
+ def fn_recursive_add_processors(
264
+ name: str,
265
+ module: torch.nn.Module,
266
+ processors: Dict[str, AttentionProcessor],
267
+ ):
268
+ if hasattr(module, "set_processor"):
269
+ processors[f"{name}.processor"] = module.processor
270
+
271
+ for sub_name, child in module.named_children():
272
+ if "temporal_transformer" not in sub_name:
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ if "temporal_transformer" not in name:
279
+ fn_recursive_add_processors(name, module, processors)
280
+
281
+ return processors
282
+
283
+ def set_attention_slice(self, slice_size):
284
+ r"""
285
+ Enable sliced attention computation.
286
+
287
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
288
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
289
+
290
+ Args:
291
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
292
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
293
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
294
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
295
+ must be a multiple of `slice_size`.
296
+ """
297
+ sliceable_head_dims = []
298
+
299
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
300
+ if hasattr(module, "set_attention_slice"):
301
+ sliceable_head_dims.append(module.sliceable_head_dim)
302
+
303
+ for child in module.children():
304
+ fn_recursive_retrieve_slicable_dims(child)
305
+
306
+ # retrieve number of attention layers
307
+ for module in self.children():
308
+ fn_recursive_retrieve_slicable_dims(module)
309
+
310
+ num_slicable_layers = len(sliceable_head_dims)
311
+
312
+ if slice_size == "auto":
313
+ # half the attention head size is usually a good trade-off between
314
+ # speed and memory
315
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
316
+ elif slice_size == "max":
317
+ # make smallest slice possible
318
+ slice_size = num_slicable_layers * [1]
319
+
320
+ slice_size = (
321
+ num_slicable_layers * [slice_size]
322
+ if not isinstance(slice_size, list)
323
+ else slice_size
324
+ )
325
+
326
+ if len(slice_size) != len(sliceable_head_dims):
327
+ raise ValueError(
328
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
329
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
330
+ )
331
+
332
+ for i in range(len(slice_size)):
333
+ size = slice_size[i]
334
+ dim = sliceable_head_dims[i]
335
+ if size is not None and size > dim:
336
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
337
+
338
+ # Recursively walk through all the children.
339
+ # Any children which exposes the set_attention_slice method
340
+ # gets the message
341
+ def fn_recursive_set_attention_slice(
342
+ module: torch.nn.Module, slice_size: List[int]
343
+ ):
344
+ if hasattr(module, "set_attention_slice"):
345
+ module.set_attention_slice(slice_size.pop())
346
+
347
+ for child in module.children():
348
+ fn_recursive_set_attention_slice(child, slice_size)
349
+
350
+ reversed_slice_size = list(reversed(slice_size))
351
+ for module in self.children():
352
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
353
+
354
+ def _set_gradient_checkpointing(self, module, value=False):
355
+ if hasattr(module, "gradient_checkpointing"):
356
+ module.gradient_checkpointing = value
357
+
358
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
359
+ def set_attn_processor(
360
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
361
+ ):
362
+ r"""
363
+ Sets the attention processor to use to compute attention.
364
+
365
+ Parameters:
366
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
367
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
368
+ for **all** `Attention` layers.
369
+
370
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
371
+ processor. This is strongly recommended when setting trainable attention processors.
372
+
373
+ """
374
+ count = len(self.attn_processors.keys())
375
+
376
+ if isinstance(processor, dict) and len(processor) != count:
377
+ raise ValueError(
378
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
379
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
380
+ )
381
+
382
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
383
+ if hasattr(module, "set_processor"):
384
+ if not isinstance(processor, dict):
385
+ module.set_processor(processor)
386
+ else:
387
+ module.set_processor(processor.pop(f"{name}.processor"))
388
+
389
+ for sub_name, child in module.named_children():
390
+ if "temporal_transformer" not in sub_name:
391
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
392
+
393
+ for name, module in self.named_children():
394
+ if "temporal_transformer" not in name:
395
+ fn_recursive_attn_processor(name, module, processor)
396
+
397
+ def forward(
398
+ self,
399
+ sample: torch.FloatTensor,
400
+ timestep: Union[torch.Tensor, float, int],
401
+ encoder_hidden_states: torch.Tensor,
402
+ class_labels: Optional[torch.Tensor] = None,
403
+ pose_cond_fea: Optional[torch.Tensor] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
406
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
407
+ return_dict: bool = True,
408
+ ) -> Union[UNet3DConditionOutput, Tuple]:
409
+ r"""
410
+ Args:
411
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
412
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
413
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
414
+ return_dict (`bool`, *optional*, defaults to `True`):
415
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
416
+
417
+ Returns:
418
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
419
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
420
+ returning a tuple, the first element is the sample tensor.
421
+ """
422
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
423
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
424
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
425
+ # on the fly if necessary.
426
+ default_overall_up_factor = 2**self.num_upsamplers
427
+
428
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
429
+ forward_upsample_size = False
430
+ upsample_size = None
431
+
432
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
433
+ logger.info("Forward upsample size to force interpolation output size.")
434
+ forward_upsample_size = True
435
+
436
+ # prepare attention_mask
437
+ if attention_mask is not None:
438
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
439
+ attention_mask = attention_mask.unsqueeze(1)
440
+
441
+ # center input if necessary
442
+ if self.config.center_input_sample:
443
+ sample = 2 * sample - 1.0
444
+
445
+ # time
446
+ timesteps = timestep
447
+ if not torch.is_tensor(timesteps):
448
+ # This would be a good case for the `match` statement (Python 3.10+)
449
+ is_mps = sample.device.type == "mps"
450
+ if isinstance(timestep, float):
451
+ dtype = torch.float32 if is_mps else torch.float64
452
+ else:
453
+ dtype = torch.int32 if is_mps else torch.int64
454
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
455
+ elif len(timesteps.shape) == 0:
456
+ timesteps = timesteps[None].to(sample.device)
457
+
458
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
459
+ timesteps = timesteps.expand(sample.shape[0])
460
+
461
+ t_emb = self.time_proj(timesteps)
462
+
463
+ # timesteps does not contain any weights and will always return f32 tensors
464
+ # but time_embedding might actually be running in fp16. so we need to cast here.
465
+ # there might be better ways to encapsulate this.
466
+ t_emb = t_emb.to(dtype=self.dtype)
467
+ emb = self.time_embedding(t_emb)
468
+
469
+ if self.class_embedding is not None:
470
+ if class_labels is None:
471
+ raise ValueError(
472
+ "class_labels should be provided when num_class_embeds > 0"
473
+ )
474
+
475
+ if self.config.class_embed_type == "timestep":
476
+ class_labels = self.time_proj(class_labels)
477
+
478
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
479
+ emb = emb + class_emb
480
+
481
+ # pre-process
482
+ sample = self.conv_in(sample)
483
+ if pose_cond_fea is not None:
484
+ sample = sample + pose_cond_fea
485
+
486
+ # down
487
+ down_block_res_samples = (sample,)
488
+ for downsample_block in self.down_blocks:
489
+ if (
490
+ hasattr(downsample_block, "has_cross_attention")
491
+ and downsample_block.has_cross_attention
492
+ ):
493
+ sample, res_samples = downsample_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ encoder_hidden_states=encoder_hidden_states,
497
+ attention_mask=attention_mask,
498
+ )
499
+ else:
500
+ sample, res_samples = downsample_block(
501
+ hidden_states=sample,
502
+ temb=emb,
503
+ encoder_hidden_states=encoder_hidden_states,
504
+ )
505
+
506
+ down_block_res_samples += res_samples
507
+
508
+ if down_block_additional_residuals is not None:
509
+ new_down_block_res_samples = ()
510
+
511
+ for down_block_res_sample, down_block_additional_residual in zip(
512
+ down_block_res_samples, down_block_additional_residuals
513
+ ):
514
+ down_block_res_sample = (
515
+ down_block_res_sample + down_block_additional_residual
516
+ )
517
+ new_down_block_res_samples += (down_block_res_sample,)
518
+
519
+ down_block_res_samples = new_down_block_res_samples
520
+
521
+ # mid
522
+ sample = self.mid_block(
523
+ sample,
524
+ emb,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ )
528
+
529
+ if mid_block_additional_residual is not None:
530
+ sample = sample + mid_block_additional_residual
531
+
532
+ # up
533
+ for i, upsample_block in enumerate(self.up_blocks):
534
+ is_final_block = i == len(self.up_blocks) - 1
535
+
536
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
537
+ down_block_res_samples = down_block_res_samples[
538
+ : -len(upsample_block.resnets)
539
+ ]
540
+
541
+ # if we have not reached the final block and need to forward the
542
+ # upsample size, we do it here
543
+ if not is_final_block and forward_upsample_size:
544
+ upsample_size = down_block_res_samples[-1].shape[2:]
545
+
546
+ if (
547
+ hasattr(upsample_block, "has_cross_attention")
548
+ and upsample_block.has_cross_attention
549
+ ):
550
+ sample = upsample_block(
551
+ hidden_states=sample,
552
+ temb=emb,
553
+ res_hidden_states_tuple=res_samples,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ upsample_size=upsample_size,
556
+ attention_mask=attention_mask,
557
+ )
558
+ else:
559
+ sample = upsample_block(
560
+ hidden_states=sample,
561
+ temb=emb,
562
+ res_hidden_states_tuple=res_samples,
563
+ upsample_size=upsample_size,
564
+ encoder_hidden_states=encoder_hidden_states,
565
+ )
566
+
567
+ # post-process
568
+ sample = self.conv_norm_out(sample)
569
+ sample = self.conv_act(sample)
570
+ sample = self.conv_out(sample)
571
+
572
+ if not return_dict:
573
+ return (sample,)
574
+
575
+ return UNet3DConditionOutput(sample=sample)
576
+
577
+ @classmethod
578
+ def from_pretrained_2d(
579
+ cls,
580
+ pretrained_model_path: PathLike,
581
+ motion_module_path: PathLike,
582
+ subfolder=None,
583
+ unet_additional_kwargs=None,
584
+ mm_zero_proj_out=False,
585
+ ):
586
+ pretrained_model_path = Path(pretrained_model_path)
587
+ motion_module_path = Path(motion_module_path)
588
+ if subfolder is not None:
589
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
590
+ logger.info(
591
+ f"loading temporal unet's pretrained weights from {pretrained_model_path} ..."
592
+ )
593
+
594
+ config_file = pretrained_model_path / "config.json"
595
+ if not (config_file.exists() and config_file.is_file()):
596
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
597
+
598
+ unet_config = cls.load_config(config_file)
599
+ unet_config["_class_name"] = cls.__name__
600
+ unet_config["down_block_types"] = [
601
+ "CrossAttnDownBlock3D",
602
+ "CrossAttnDownBlock3D",
603
+ "CrossAttnDownBlock3D",
604
+ "DownBlock3D",
605
+ ]
606
+ unet_config["up_block_types"] = [
607
+ "UpBlock3D",
608
+ "CrossAttnUpBlock3D",
609
+ "CrossAttnUpBlock3D",
610
+ "CrossAttnUpBlock3D",
611
+ ]
612
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
613
+
614
+ model = cls.from_config(unet_config, **(unet_additional_kwargs or {}))
615
+ # load the vanilla weights
616
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
617
+ logger.debug(
618
+ f"loading safeTensors weights from {pretrained_model_path} ..."
619
+ )
620
+ state_dict = load_file(
621
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
622
+ )
623
+
624
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
625
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
626
+ state_dict = torch.load(
627
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
628
+ map_location="cpu",
629
+ weights_only=True,
630
+ )
631
+ else:
632
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
633
+
634
+ # load the motion module weights
635
+ if motion_module_path.exists() and motion_module_path.is_file():
636
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
637
+ logger.info(f"Load motion module params from {motion_module_path}")
638
+ motion_state_dict = torch.load(
639
+ motion_module_path, map_location="cpu", weights_only=True
640
+ )
641
+ elif motion_module_path.suffix.lower() == ".safetensors":
642
+ motion_state_dict = load_file(motion_module_path, device="cpu")
643
+ else:
644
+ raise RuntimeError(
645
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
646
+ )
647
+ if mm_zero_proj_out:
648
+ logger.info(f"Zero initialize proj_out layers in motion module...")
649
+ new_motion_state_dict = OrderedDict()
650
+ for k in motion_state_dict:
651
+ if "proj_out" in k:
652
+ continue
653
+ new_motion_state_dict[k] = motion_state_dict[k]
654
+ motion_state_dict = new_motion_state_dict
655
+
656
+ # merge the state dicts
657
+ state_dict.update(motion_state_dict)
658
+
659
+ # load the weights into the model
660
+ m, u = model.load_state_dict(state_dict, strict=False)
661
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
662
+
663
+ params = [
664
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
665
+ ]
666
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
667
+
668
+ return model
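from_pretrained_2d above builds the 3D UNet from a 2D Stable Diffusion checkpoint plus a separate motion-module file, then merges the two state dicts and loads them with strict=False. A minimal usage sketch; the paths and the unet_additional_kwargs values are hypothetical and should be taken from this repo's inference configs:

from src.models.unet_3d import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained_2d(
    "./pretrained_weights/stable-diffusion-v1-5",   # hypothetical path to the base 2D UNet
    "./pretrained_weights/motion_module.pth",       # hypothetical path to the motion module
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": True,
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_type": "Vanilla",            # assumed type name for the AnimateDiff-style module
        "motion_module_kwargs": {},
    },
)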
src/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,862 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
432
+ hidden_states = (
433
+ motion_module(
434
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
435
+ )
436
+ if motion_module is not None
437
+ else hidden_states
438
+ )
439
+
440
+ else:
441
+ hidden_states = resnet(hidden_states, temb)
442
+ hidden_states = attn(
443
+ hidden_states,
444
+ encoder_hidden_states=encoder_hidden_states,
445
+ ).sample
446
+
447
+ # add motion module
448
+ hidden_states = (
449
+ motion_module(
450
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
451
+ )
452
+ if motion_module is not None
453
+ else hidden_states
454
+ )
455
+
456
+ output_states += (hidden_states,)
457
+
458
+ if self.downsamplers is not None:
459
+ for downsampler in self.downsamplers:
460
+ hidden_states = downsampler(hidden_states)
461
+
462
+ output_states += (hidden_states,)
463
+
464
+ return hidden_states, output_states
465
+
466
+
467
+ class DownBlock3D(nn.Module):
468
+ def __init__(
469
+ self,
470
+ in_channels: int,
471
+ out_channels: int,
472
+ temb_channels: int,
473
+ dropout: float = 0.0,
474
+ num_layers: int = 1,
475
+ resnet_eps: float = 1e-6,
476
+ resnet_time_scale_shift: str = "default",
477
+ resnet_act_fn: str = "swish",
478
+ resnet_groups: int = 32,
479
+ resnet_pre_norm: bool = True,
480
+ output_scale_factor=1.0,
481
+ add_downsample=True,
482
+ downsample_padding=1,
483
+ use_inflated_groupnorm=None,
484
+ use_motion_module=None,
485
+ motion_module_type=None,
486
+ motion_module_kwargs=None,
487
+ ):
488
+ super().__init__()
489
+ resnets = []
490
+ motion_modules = []
491
+
492
+ # use_motion_module = False
493
+ for i in range(num_layers):
494
+ in_channels = in_channels if i == 0 else out_channels
495
+ resnets.append(
496
+ ResnetBlock3D(
497
+ in_channels=in_channels,
498
+ out_channels=out_channels,
499
+ temb_channels=temb_channels,
500
+ eps=resnet_eps,
501
+ groups=resnet_groups,
502
+ dropout=dropout,
503
+ time_embedding_norm=resnet_time_scale_shift,
504
+ non_linearity=resnet_act_fn,
505
+ output_scale_factor=output_scale_factor,
506
+ pre_norm=resnet_pre_norm,
507
+ use_inflated_groupnorm=use_inflated_groupnorm,
508
+ )
509
+ )
510
+ motion_modules.append(
511
+ get_motion_module(
512
+ in_channels=out_channels,
513
+ motion_module_type=motion_module_type,
514
+ motion_module_kwargs=motion_module_kwargs,
515
+ )
516
+ if use_motion_module
517
+ else None
518
+ )
519
+
520
+ self.resnets = nn.ModuleList(resnets)
521
+ self.motion_modules = nn.ModuleList(motion_modules)
522
+
523
+ if add_downsample:
524
+ self.downsamplers = nn.ModuleList(
525
+ [
526
+ Downsample3D(
527
+ out_channels,
528
+ use_conv=True,
529
+ out_channels=out_channels,
530
+ padding=downsample_padding,
531
+ name="op",
532
+ )
533
+ ]
534
+ )
535
+ else:
536
+ self.downsamplers = None
537
+
538
+ self.gradient_checkpointing = False
539
+
540
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
541
+ output_states = ()
542
+
543
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
544
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
545
+ if self.training and self.gradient_checkpointing:
546
+
547
+ def create_custom_forward(module):
548
+ def custom_forward(*inputs):
549
+ return module(*inputs)
550
+
551
+ return custom_forward
552
+
553
+ hidden_states = torch.utils.checkpoint.checkpoint(
554
+ create_custom_forward(resnet), hidden_states, temb
555
+ )
556
+ if motion_module is not None:
557
+ hidden_states = torch.utils.checkpoint.checkpoint(
558
+ create_custom_forward(motion_module),
559
+ hidden_states.requires_grad_(),
560
+ temb,
561
+ encoder_hidden_states,
562
+ )
563
+ else:
564
+ hidden_states = resnet(hidden_states, temb)
565
+
566
+ # add motion module
567
+ hidden_states = (
568
+ motion_module(
569
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
570
+ )
571
+ if motion_module is not None
572
+ else hidden_states
573
+ )
574
+
575
+ output_states += (hidden_states,)
576
+
577
+ if self.downsamplers is not None:
578
+ for downsampler in self.downsamplers:
579
+ hidden_states = downsampler(hidden_states)
580
+
581
+ output_states += (hidden_states,)
582
+
583
+ return hidden_states, output_states
584
+
585
+
586
+ class CrossAttnUpBlock3D(nn.Module):
587
+ def __init__(
588
+ self,
589
+ in_channels: int,
590
+ out_channels: int,
591
+ prev_output_channel: int,
592
+ temb_channels: int,
593
+ dropout: float = 0.0,
594
+ num_layers: int = 1,
595
+ resnet_eps: float = 1e-6,
596
+ resnet_time_scale_shift: str = "default",
597
+ resnet_act_fn: str = "swish",
598
+ resnet_groups: int = 32,
599
+ resnet_pre_norm: bool = True,
600
+ attn_num_head_channels=1,
601
+ cross_attention_dim=1280,
602
+ output_scale_factor=1.0,
603
+ add_upsample=True,
604
+ dual_cross_attention=False,
605
+ use_linear_projection=False,
606
+ only_cross_attention=False,
607
+ upcast_attention=False,
608
+ unet_use_cross_frame_attention=None,
609
+ unet_use_temporal_attention=None,
610
+ use_motion_module=None,
611
+ use_inflated_groupnorm=None,
612
+ motion_module_type=None,
613
+ motion_module_kwargs=None,
614
+ ):
615
+ super().__init__()
616
+ resnets = []
617
+ attentions = []
618
+ motion_modules = []
619
+
620
+ self.has_cross_attention = True
621
+ self.attn_num_head_channels = attn_num_head_channels
622
+
623
+ for i in range(num_layers):
624
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
625
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
626
+
627
+ resnets.append(
628
+ ResnetBlock3D(
629
+ in_channels=resnet_in_channels + res_skip_channels,
630
+ out_channels=out_channels,
631
+ temb_channels=temb_channels,
632
+ eps=resnet_eps,
633
+ groups=resnet_groups,
634
+ dropout=dropout,
635
+ time_embedding_norm=resnet_time_scale_shift,
636
+ non_linearity=resnet_act_fn,
637
+ output_scale_factor=output_scale_factor,
638
+ pre_norm=resnet_pre_norm,
639
+ use_inflated_groupnorm=use_inflated_groupnorm,
640
+ )
641
+ )
642
+ if dual_cross_attention:
643
+ raise NotImplementedError
644
+ attentions.append(
645
+ Transformer3DModel(
646
+ attn_num_head_channels,
647
+ out_channels // attn_num_head_channels,
648
+ in_channels=out_channels,
649
+ num_layers=1,
650
+ cross_attention_dim=cross_attention_dim,
651
+ norm_num_groups=resnet_groups,
652
+ use_linear_projection=use_linear_projection,
653
+ only_cross_attention=only_cross_attention,
654
+ upcast_attention=upcast_attention,
655
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
656
+ unet_use_temporal_attention=unet_use_temporal_attention,
657
+ )
658
+ )
659
+ motion_modules.append(
660
+ get_motion_module(
661
+ in_channels=out_channels,
662
+ motion_module_type=motion_module_type,
663
+ motion_module_kwargs=motion_module_kwargs,
664
+ )
665
+ if use_motion_module
666
+ else None
667
+ )
668
+
669
+ self.attentions = nn.ModuleList(attentions)
670
+ self.resnets = nn.ModuleList(resnets)
671
+ self.motion_modules = nn.ModuleList(motion_modules)
672
+
673
+ if add_upsample:
674
+ self.upsamplers = nn.ModuleList(
675
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
676
+ )
677
+ else:
678
+ self.upsamplers = None
679
+
680
+ self.gradient_checkpointing = False
681
+
682
+ def forward(
683
+ self,
684
+ hidden_states,
685
+ res_hidden_states_tuple,
686
+ temb=None,
687
+ encoder_hidden_states=None,
688
+ upsample_size=None,
689
+ attention_mask=None,
690
+ ):
691
+ for i, (resnet, attn, motion_module) in enumerate(
692
+ zip(self.resnets, self.attentions, self.motion_modules)
693
+ ):
694
+ # pop res hidden states
695
+ res_hidden_states = res_hidden_states_tuple[-1]
696
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
697
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
698
+
699
+ if self.training and self.gradient_checkpointing:
700
+
701
+ def create_custom_forward(module, return_dict=None):
702
+ def custom_forward(*inputs):
703
+ if return_dict is not None:
704
+ return module(*inputs, return_dict=return_dict)
705
+ else:
706
+ return module(*inputs)
707
+
708
+ return custom_forward
709
+
710
+ hidden_states = torch.utils.checkpoint.checkpoint(
711
+ create_custom_forward(resnet), hidden_states, temb
712
+ )
713
+ hidden_states = attn(
714
+ hidden_states,
715
+ encoder_hidden_states=encoder_hidden_states,
716
+ ).sample
717
+ if motion_module is not None:
718
+ hidden_states = torch.utils.checkpoint.checkpoint(
719
+ create_custom_forward(motion_module),
720
+ hidden_states.requires_grad_(),
721
+ temb,
722
+ encoder_hidden_states,
723
+ )
724
+
725
+ else:
726
+ hidden_states = resnet(hidden_states, temb)
727
+ hidden_states = attn(
728
+ hidden_states,
729
+ encoder_hidden_states=encoder_hidden_states,
730
+ ).sample
731
+
732
+ # add motion module
733
+ hidden_states = (
734
+ motion_module(
735
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
736
+ )
737
+ if motion_module is not None
738
+ else hidden_states
739
+ )
740
+
741
+ if self.upsamplers is not None:
742
+ for upsampler in self.upsamplers:
743
+ hidden_states = upsampler(hidden_states, upsample_size)
744
+
745
+ return hidden_states
746
+
747
+
748
+ class UpBlock3D(nn.Module):
749
+ def __init__(
750
+ self,
751
+ in_channels: int,
752
+ prev_output_channel: int,
753
+ out_channels: int,
754
+ temb_channels: int,
755
+ dropout: float = 0.0,
756
+ num_layers: int = 1,
757
+ resnet_eps: float = 1e-6,
758
+ resnet_time_scale_shift: str = "default",
759
+ resnet_act_fn: str = "swish",
760
+ resnet_groups: int = 32,
761
+ resnet_pre_norm: bool = True,
762
+ output_scale_factor=1.0,
763
+ add_upsample=True,
764
+ use_inflated_groupnorm=None,
765
+ use_motion_module=None,
766
+ motion_module_type=None,
767
+ motion_module_kwargs=None,
768
+ ):
769
+ super().__init__()
770
+ resnets = []
771
+ motion_modules = []
772
+
773
+ # use_motion_module = False
774
+ for i in range(num_layers):
775
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
776
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
777
+
778
+ resnets.append(
779
+ ResnetBlock3D(
780
+ in_channels=resnet_in_channels + res_skip_channels,
781
+ out_channels=out_channels,
782
+ temb_channels=temb_channels,
783
+ eps=resnet_eps,
784
+ groups=resnet_groups,
785
+ dropout=dropout,
786
+ time_embedding_norm=resnet_time_scale_shift,
787
+ non_linearity=resnet_act_fn,
788
+ output_scale_factor=output_scale_factor,
789
+ pre_norm=resnet_pre_norm,
790
+ use_inflated_groupnorm=use_inflated_groupnorm,
791
+ )
792
+ )
793
+ motion_modules.append(
794
+ get_motion_module(
795
+ in_channels=out_channels,
796
+ motion_module_type=motion_module_type,
797
+ motion_module_kwargs=motion_module_kwargs,
798
+ )
799
+ if use_motion_module
800
+ else None
801
+ )
802
+
803
+ self.resnets = nn.ModuleList(resnets)
804
+ self.motion_modules = nn.ModuleList(motion_modules)
805
+
806
+ if add_upsample:
807
+ self.upsamplers = nn.ModuleList(
808
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
809
+ )
810
+ else:
811
+ self.upsamplers = None
812
+
813
+ self.gradient_checkpointing = False
814
+
815
+ def forward(
816
+ self,
817
+ hidden_states,
818
+ res_hidden_states_tuple,
819
+ temb=None,
820
+ upsample_size=None,
821
+ encoder_hidden_states=None,
822
+ ):
823
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
824
+ # pop res hidden states
825
+ res_hidden_states = res_hidden_states_tuple[-1]
826
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
827
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
828
+
829
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
830
+ if self.training and self.gradient_checkpointing:
831
+
832
+ def create_custom_forward(module):
833
+ def custom_forward(*inputs):
834
+ return module(*inputs)
835
+
836
+ return custom_forward
837
+
838
+ hidden_states = torch.utils.checkpoint.checkpoint(
839
+ create_custom_forward(resnet), hidden_states, temb
840
+ )
841
+ if motion_module is not None:
842
+ hidden_states = torch.utils.checkpoint.checkpoint(
843
+ create_custom_forward(motion_module),
844
+ hidden_states.requires_grad_(),
845
+ temb,
846
+ encoder_hidden_states,
847
+ )
848
+ else:
849
+ hidden_states = resnet(hidden_states, temb)
850
+ hidden_states = (
851
+ motion_module(
852
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
853
+ )
854
+ if motion_module is not None
855
+ else hidden_states
856
+ )
857
+
858
+ if self.upsamplers is not None:
859
+ for upsampler in self.upsamplers:
860
+ hidden_states = upsampler(hidden_states, upsample_size)
861
+
862
+ return hidden_states
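
Purely as a sketch (not part of this commit): the block factory above can be smoke-tested on a random 5-D video tensor. This assumes the repo's requirements are installed, that get_down_block's full signature mirrors get_up_block shown above, and that ResnetBlock3D / Downsample3D in src/models/resnet.py behave as in the AnimateDiff code they are adapted from.

import torch
from src.models.unet_3d_blocks import get_down_block

# Plain DownBlock3D with no temporal motion module, so no motion_module kwargs are needed.
block = get_down_block(
    "DownBlock3D",
    num_layers=1,
    in_channels=64,
    out_channels=64,
    temb_channels=512,
    add_downsample=True,
    resnet_eps=1e-6,
    resnet_act_fn="swish",
    attn_num_head_channels=8,     # unused by DownBlock3D, required by the factory signature
    resnet_groups=32,
    downsample_padding=1,
    use_inflated_groupnorm=True,
    use_motion_module=False,
)

x = torch.randn(1, 64, 8, 32, 32)   # (batch, channels, frames, height, width)
temb = torch.randn(1, 512)          # timestep embedding consumed by the resnets
out, skips = block(x, temb)
print(out.shape)                    # expected (1, 64, 8, 16, 16) after the stride-2 downsample
print(len(skips))                   # one skip per resnet plus one taken after downsampling
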
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_steps: Optional[int] = None,
18
+ num_frames: int = ...,
19
+ context_size: Optional[int] = None,
20
+ context_stride: int = 3,
21
+ context_overlap: int = 4,
22
+ closed_loop: bool = True,
23
+ ):
24
+ if num_frames <= context_size:
25
+ yield list(range(num_frames))
26
+ return
27
+
28
+ context_stride = min(
29
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30
+ )
31
+
32
+ for context_step in 1 << np.arange(context_stride):
33
+ pad = int(round(num_frames * ordered_halving(step)))
34
+ for j in range(
35
+ int(ordered_halving(step) * context_step) + pad,
36
+ num_frames + pad + (0 if closed_loop else -context_overlap),
37
+ (context_size * context_step - context_overlap),
38
+ ):
39
+ yield [
40
+ e % num_frames
41
+ for e in range(j, j + context_size * context_step, context_step)
42
+ ]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context schedule {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
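
As a usage note (a minimal sketch, assuming the repository root is on PYTHONPATH): the uniform scheduler yields overlapping, wrap-around windows of frame indices, which is how pipeline_pose2vid_long.py below splits a long video into chunks the motion module can process.

from src.pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")
windows = list(
    scheduler(
        step=0,
        num_steps=20,
        num_frames=72,
        context_size=24,
        context_stride=1,
        context_overlap=4,
    )
)
for w in windows:
    print(w[0], "...", w[-1], f"({len(w)} frames)")
# Four windows of 24 frames, each overlapping its neighbour by 4 frames;
# the last window wraps around to the start because closed_loop defaults to True.
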
src/pipelines/pipeline_pose2vid.py ADDED
@@ -0,0 +1,454 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
12
+ PNDMScheduler)
13
+ from diffusers.utils import BaseOutput, is_accelerate_available
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from einops import rearrange
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor
18
+
19
+ from src.models.mutual_self_attention import ReferenceAttentionControl
20
+
21
+
22
+ @dataclass
23
+ class Pose2VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+ middle_results: Optional[Union[torch.Tensor, np.ndarray]] = None
26
+
27
+
28
+ class Pose2VideoPipeline(DiffusionPipeline):
29
+ _optional_components = []
30
+
31
+ def __init__(
32
+ self,
33
+ vae,
34
+ image_encoder,
35
+ reference_unet,
36
+ denoising_unet,
37
+ pose_guider,
38
+ scheduler: Union[
39
+ DDIMScheduler,
40
+ PNDMScheduler,
41
+ LMSDiscreteScheduler,
42
+ EulerDiscreteScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ DPMSolverMultistepScheduler,
45
+ ],
46
+ image_proj_model=None,
47
+ tokenizer=None,
48
+ text_encoder=None,
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ image_encoder=image_encoder,
55
+ reference_unet=reference_unet,
56
+ denoising_unet=denoising_unet,
57
+ pose_guider=pose_guider,
58
+ scheduler=scheduler,
59
+ image_proj_model=image_proj_model,
60
+ tokenizer=tokenizer,
61
+ text_encoder=text_encoder,
62
+ )
63
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
64
+ self.clip_image_processor = CLIPImageProcessor()
65
+ self.ref_image_processor = VaeImageProcessor(
66
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
67
+ )
68
+ self.cond_image_processor = VaeImageProcessor(
69
+ vae_scale_factor=self.vae_scale_factor,
70
+ do_convert_rgb=True,
71
+ do_normalize=False,
72
+ )
73
+
74
+ def enable_vae_slicing(self):
75
+ self.vae.enable_slicing()
76
+
77
+ def disable_vae_slicing(self):
78
+ self.vae.disable_slicing()
79
+
80
+ def enable_sequential_cpu_offload(self, gpu_id=0):
81
+ if is_accelerate_available():
82
+ from accelerate import cpu_offload
83
+ else:
84
+ raise ImportError("Please install accelerate via `pip install accelerate`")
85
+
86
+ device = torch.device(f"cuda:{gpu_id}")
87
+
88
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
89
+ if cpu_offloaded_model is not None:
90
+ cpu_offload(cpu_offloaded_model, device)
91
+
92
+ @property
93
+ def _execution_device(self):
94
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
95
+ return self.device
96
+ for module in self.unet.modules():
97
+ if (
98
+ hasattr(module, "_hf_hook")
99
+ and hasattr(module._hf_hook, "execution_device")
100
+ and module._hf_hook.execution_device is not None
101
+ ):
102
+ return torch.device(module._hf_hook.execution_device)
103
+ return self.device
104
+
105
+ def decode_latents(self, latents):
106
+ video_length = latents.shape[2]
107
+ latents = 1 / 0.18215 * latents
108
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
109
+ # video = self.vae.decode(latents).sample
110
+ video = []
111
+ for frame_idx in tqdm(range(latents.shape[0])):
112
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
113
+ video = torch.cat(video)
114
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
115
+ video = (video / 2 + 0.5).clamp(0, 1)
116
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
117
+ video = video.cpu().float().numpy()
118
+ return video
119
+
120
+ def prepare_extra_step_kwargs(self, generator, eta):
121
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
122
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
123
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
124
+ # and should be between [0, 1]
125
+
126
+ accepts_eta = "eta" in set(
127
+ inspect.signature(self.scheduler.step).parameters.keys()
128
+ )
129
+ extra_step_kwargs = {}
130
+ if accepts_eta:
131
+ extra_step_kwargs["eta"] = eta
132
+
133
+ # check if the scheduler accepts generator
134
+ accepts_generator = "generator" in set(
135
+ inspect.signature(self.scheduler.step).parameters.keys()
136
+ )
137
+ if accepts_generator:
138
+ extra_step_kwargs["generator"] = generator
139
+ return extra_step_kwargs
140
+
141
+ def prepare_latents(
142
+ self,
143
+ batch_size,
144
+ num_channels_latents,
145
+ width,
146
+ height,
147
+ video_length,
148
+ dtype,
149
+ device,
150
+ generator,
151
+ latents=None,
152
+ ):
153
+ shape = (
154
+ batch_size,
155
+ num_channels_latents,
156
+ video_length,
157
+ height // self.vae_scale_factor,
158
+ width // self.vae_scale_factor,
159
+ )
160
+ if isinstance(generator, list) and len(generator) != batch_size:
161
+ raise ValueError(
162
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
163
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
164
+ )
165
+
166
+ if latents is None:
167
+ latents = randn_tensor(
168
+ shape, generator=generator, device=device, dtype=dtype
169
+ )
170
+ else:
171
+ latents = latents.to(device)
172
+
173
+ # scale the initial noise by the standard deviation required by the scheduler
174
+ latents = latents * self.scheduler.init_noise_sigma
175
+ return latents
176
+
177
+ def _encode_prompt(
178
+ self,
179
+ prompt,
180
+ device,
181
+ num_videos_per_prompt,
182
+ do_classifier_free_guidance,
183
+ negative_prompt,
184
+ ):
185
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
186
+
187
+ text_inputs = self.tokenizer(
188
+ prompt,
189
+ padding="max_length",
190
+ max_length=self.tokenizer.model_max_length,
191
+ truncation=True,
192
+ return_tensors="pt",
193
+ )
194
+ text_input_ids = text_inputs.input_ids
195
+ untruncated_ids = self.tokenizer(
196
+ prompt, padding="longest", return_tensors="pt"
197
+ ).input_ids
198
+
199
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
200
+ text_input_ids, untruncated_ids
201
+ ):
202
+ removed_text = self.tokenizer.batch_decode(
203
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
204
+ )
205
+
206
+ if (
207
+ hasattr(self.text_encoder.config, "use_attention_mask")
208
+ and self.text_encoder.config.use_attention_mask
209
+ ):
210
+ attention_mask = text_inputs.attention_mask.to(device)
211
+ else:
212
+ attention_mask = None
213
+
214
+ text_embeddings = self.text_encoder(
215
+ text_input_ids.to(device),
216
+ attention_mask=attention_mask,
217
+ )
218
+ text_embeddings = text_embeddings[0]
219
+
220
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
221
+ bs_embed, seq_len, _ = text_embeddings.shape
222
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
223
+ text_embeddings = text_embeddings.view(
224
+ bs_embed * num_videos_per_prompt, seq_len, -1
225
+ )
226
+
227
+ # get unconditional embeddings for classifier free guidance
228
+ if do_classifier_free_guidance:
229
+ uncond_tokens: List[str]
230
+ if negative_prompt is None:
231
+ uncond_tokens = [""] * batch_size
232
+ elif type(prompt) is not type(negative_prompt):
233
+ raise TypeError(
234
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
235
+ f" {type(prompt)}."
236
+ )
237
+ elif isinstance(negative_prompt, str):
238
+ uncond_tokens = [negative_prompt]
239
+ elif batch_size != len(negative_prompt):
240
+ raise ValueError(
241
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
242
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
243
+ " the batch size of `prompt`."
244
+ )
245
+ else:
246
+ uncond_tokens = negative_prompt
247
+
248
+ max_length = text_input_ids.shape[-1]
249
+ uncond_input = self.tokenizer(
250
+ uncond_tokens,
251
+ padding="max_length",
252
+ max_length=max_length,
253
+ truncation=True,
254
+ return_tensors="pt",
255
+ )
256
+
257
+ if (
258
+ hasattr(self.text_encoder.config, "use_attention_mask")
259
+ and self.text_encoder.config.use_attention_mask
260
+ ):
261
+ attention_mask = uncond_input.attention_mask.to(device)
262
+ else:
263
+ attention_mask = None
264
+
265
+ uncond_embeddings = self.text_encoder(
266
+ uncond_input.input_ids.to(device),
267
+ attention_mask=attention_mask,
268
+ )
269
+ uncond_embeddings = uncond_embeddings[0]
270
+
271
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
272
+ seq_len = uncond_embeddings.shape[1]
273
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
274
+ uncond_embeddings = uncond_embeddings.view(
275
+ batch_size * num_videos_per_prompt, seq_len, -1
276
+ )
277
+
278
+ # For classifier free guidance, we need to do two forward passes.
279
+ # Here we concatenate the unconditional and text embeddings into a single batch
280
+ # to avoid doing two forward passes
281
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
282
+
283
+ return text_embeddings
284
+
285
+ @torch.no_grad()
286
+ def __call__(
287
+ self,
288
+ ref_image,
289
+ pose_images,
290
+ width,
291
+ height,
292
+ video_length,
293
+ num_inference_steps,
294
+ guidance_scale,
295
+ num_images_per_prompt=1,
296
+ eta: float = 0.0,
297
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
298
+ output_type: Optional[str] = "tensor",
299
+ return_dict: bool = True,
300
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
301
+ callback_steps: Optional[int] = 1,
302
+ **kwargs,
303
+ ):
304
+ # Default height and width to unet
305
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
306
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
307
+
308
+ device = self._execution_device
309
+
310
+ do_classifier_free_guidance = guidance_scale > 1.0
311
+
312
+ # Prepare timesteps
313
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
314
+ timesteps = self.scheduler.timesteps
315
+
316
+ batch_size = 1
317
+
318
+ # Prepare clip image embeds
319
+ clip_image = self.clip_image_processor.preprocess(
320
+ ref_image, return_tensors="pt"
321
+ ).pixel_values
322
+ clip_image_embeds = self.image_encoder(
323
+ clip_image.to(device, dtype=self.image_encoder.dtype)
324
+ ).image_embeds
325
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
326
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
327
+
328
+ if do_classifier_free_guidance:
329
+ encoder_hidden_states = torch.cat(
330
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
331
+ )
332
+ reference_control_writer = ReferenceAttentionControl(
333
+ self.reference_unet,
334
+ do_classifier_free_guidance=do_classifier_free_guidance,
335
+ mode="write",
336
+ batch_size=batch_size,
337
+ fusion_blocks="full",
338
+ )
339
+ reference_control_reader = ReferenceAttentionControl(
340
+ self.denoising_unet,
341
+ do_classifier_free_guidance=do_classifier_free_guidance,
342
+ mode="read",
343
+ batch_size=batch_size,
344
+ fusion_blocks="full",
345
+ )
346
+
347
+ num_channels_latents = self.denoising_unet.in_channels
348
+ latents = self.prepare_latents(
349
+ batch_size * num_images_per_prompt,
350
+ num_channels_latents,
351
+ width,
352
+ height,
353
+ video_length,
354
+ clip_image_embeds.dtype,
355
+ device,
356
+ generator,
357
+ )
358
+
359
+ # Prepare extra step kwargs.
360
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
361
+
362
+ # Prepare ref image latents
363
+ ref_image_tensor = self.ref_image_processor.preprocess(
364
+ ref_image, height=height, width=width
365
+ ) # (bs, c, h, w)
366
+ ref_image_tensor = ref_image_tensor.to(
367
+ dtype=self.vae.dtype, device=self.vae.device
368
+ )
369
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
370
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
371
+
372
+ # Prepare a list of pose condition images
373
+ pose_cond_tensor_list = []
374
+ for pose_image in pose_images:
375
+ pose_cond_tensor = (
376
+ torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
377
+ )
378
+ pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
379
+ 1
380
+ ) # (c, 1, h, w)
381
+ pose_cond_tensor_list.append(pose_cond_tensor)
382
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
383
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
384
+ pose_cond_tensor = pose_cond_tensor.to(
385
+ device=device, dtype=self.pose_guider.dtype
386
+ )
387
+ pose_fea = self.pose_guider(pose_cond_tensor)
388
+ pose_fea = (
389
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
390
+ )
391
+
392
+ # denoising loop
393
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
394
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
395
+ for i, t in enumerate(timesteps):
396
+ # 1. Forward reference image
397
+ if i == 0:
398
+ self.reference_unet(
399
+ ref_image_latents.repeat(
400
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
401
+ ),
402
+ torch.zeros_like(t),
403
+ # t,
404
+ encoder_hidden_states=encoder_hidden_states,
405
+ return_dict=False,
406
+ )
407
+ reference_control_reader.update(reference_control_writer)
408
+
409
+ # 3.1 expand the latents if we are doing classifier free guidance
410
+ latent_model_input = (
411
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
412
+ )
413
+ latent_model_input = self.scheduler.scale_model_input(
414
+ latent_model_input, t
415
+ )
416
+
417
+ noise_pred = self.denoising_unet(
418
+ latent_model_input,
419
+ t,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ pose_cond_fea=pose_fea,
422
+ return_dict=False,
423
+ )[0]
424
+
425
+ # perform guidance
426
+ if do_classifier_free_guidance:
427
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
428
+ noise_pred = noise_pred_uncond + guidance_scale * (
429
+ noise_pred_text - noise_pred_uncond
430
+ )
431
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs
+ ).prev_sample
+
432
+ # call the callback, if provided
433
+ if i == len(timesteps) - 1 or (
434
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
435
+ ):
436
+ progress_bar.update()
437
+ if callback is not None and i % callback_steps == 0:
438
+ step_idx = i // getattr(self.scheduler, "order", 1)
439
+ callback(step_idx, t, latents)
440
+
441
+ reference_control_reader.clear()
442
+ reference_control_writer.clear()
443
+
444
+ # Post-processing
445
+ images = self.decode_latents(latents) # (b, c, f, h, w)
446
+
447
+ # Convert to tensor
448
+ if output_type == "tensor":
449
+ images = torch.from_numpy(images)
450
+
451
+ if not return_dict:
452
+ return images
453
+
454
+ return Pose2VideoPipelineOutput(videos=images)
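
A hypothetical call sketch, not part of the commit: `pipe` stands for a Pose2VideoPipeline already assembled from the modules listed in __init__ (for example by scripts/pose2vid.py), `pose_pils` for a list of per-frame DWPose renderings as PIL images, and the path, resolution and step counts are placeholder values.

from PIL import Image

ref_image = Image.open("ref.png").convert("RGB")   # hypothetical reference image path

output = pipe(                                     # `pipe`: an assembled Pose2VideoPipeline (assumed)
    ref_image,
    pose_pils,                                     # hypothetical list of pose frames (PIL images)
    width=512,
    height=768,
    video_length=len(pose_pils),
    num_inference_steps=25,
    guidance_scale=3.5,
)
video = output.videos   # (b, c, f, h, w), values in [0, 1] with the default output_type="tensor"
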
src/pipelines/pipeline_pose2vid_long.py ADDED
@@ -0,0 +1,571 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import DiffusionPipeline
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler,
16
+ LMSDiscreteScheduler,
17
+ PNDMScheduler,
18
+ )
19
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from einops import rearrange
22
+ from tqdm import tqdm
23
+ from transformers import CLIPImageProcessor
24
+
25
+ from src.models.mutual_self_attention import ReferenceAttentionControl
26
+ from src.pipelines.context import get_context_scheduler
27
+ from src.pipelines.utils import get_tensor_interpolation_method
28
+
29
+
30
+ @dataclass
31
+ class Pose2VideoPipelineOutput(BaseOutput):
32
+ videos: Union[torch.Tensor, np.ndarray]
33
+
34
+
35
+ class Pose2VideoPipeline(DiffusionPipeline):
36
+ _optional_components = []
37
+
38
+ def __init__(
39
+ self,
40
+ vae,
41
+ image_encoder,
42
+ reference_unet,
43
+ denoising_unet,
44
+ pose_guider,
45
+ scheduler: Union[
46
+ DDIMScheduler,
47
+ PNDMScheduler,
48
+ LMSDiscreteScheduler,
49
+ EulerDiscreteScheduler,
50
+ EulerAncestralDiscreteScheduler,
51
+ DPMSolverMultistepScheduler,
52
+ ],
53
+ image_proj_model=None,
54
+ tokenizer=None,
55
+ text_encoder=None,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.register_modules(
60
+ vae=vae,
61
+ image_encoder=image_encoder,
62
+ reference_unet=reference_unet,
63
+ denoising_unet=denoising_unet,
64
+ pose_guider=pose_guider,
65
+ scheduler=scheduler,
66
+ image_proj_model=image_proj_model,
67
+ tokenizer=tokenizer,
68
+ text_encoder=text_encoder,
69
+ )
70
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
71
+ self.clip_image_processor = CLIPImageProcessor()
72
+ self.ref_image_processor = VaeImageProcessor(
73
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
74
+ )
75
+ self.cond_image_processor = VaeImageProcessor(
76
+ vae_scale_factor=self.vae_scale_factor,
77
+ do_convert_rgb=True,
78
+ do_normalize=False,
79
+ )
80
+
81
+ def enable_vae_slicing(self):
82
+ self.vae.enable_slicing()
83
+
84
+ def disable_vae_slicing(self):
85
+ self.vae.disable_slicing()
86
+
87
+ def enable_sequential_cpu_offload(self, gpu_id=0):
88
+ if is_accelerate_available():
89
+ from accelerate import cpu_offload
90
+ else:
91
+ raise ImportError("Please install accelerate via `pip install accelerate`")
92
+
93
+ device = torch.device(f"cuda:{gpu_id}")
94
+
95
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
96
+ if cpu_offloaded_model is not None:
97
+ cpu_offload(cpu_offloaded_model, device)
98
+
99
+ @property
100
+ def _execution_device(self):
101
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
102
+ return self.device
103
+ for module in self.unet.modules():
104
+ if (
105
+ hasattr(module, "_hf_hook")
106
+ and hasattr(module._hf_hook, "execution_device")
107
+ and module._hf_hook.execution_device is not None
108
+ ):
109
+ return torch.device(module._hf_hook.execution_device)
110
+ return self.device
111
+
112
+ def decode_latents(self, latents):
113
+ video_length = latents.shape[2]
114
+ latents = 1 / 0.18215 * latents
115
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
116
+ # video = self.vae.decode(latents).sample
117
+ video = []
118
+ for frame_idx in tqdm(range(latents.shape[0])):
119
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
120
+ video = torch.cat(video)
121
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
122
+ video = (video / 2 + 0.5).clamp(0, 1)
123
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
124
+ video = video.cpu().float().numpy()
125
+ return video
126
+
127
+ def prepare_extra_step_kwargs(self, generator, eta):
128
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
129
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
130
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
131
+ # and should be between [0, 1]
132
+
133
+ accepts_eta = "eta" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ extra_step_kwargs = {}
137
+ if accepts_eta:
138
+ extra_step_kwargs["eta"] = eta
139
+
140
+ # check if the scheduler accepts generator
141
+ accepts_generator = "generator" in set(
142
+ inspect.signature(self.scheduler.step).parameters.keys()
143
+ )
144
+ if accepts_generator:
145
+ extra_step_kwargs["generator"] = generator
146
+ return extra_step_kwargs
147
+
148
+ def prepare_latents(
149
+ self,
150
+ batch_size,
151
+ num_channels_latents,
152
+ width,
153
+ height,
154
+ video_length,
155
+ dtype,
156
+ device,
157
+ generator,
158
+ latents=None,
159
+ ):
160
+ shape = (
161
+ batch_size,
162
+ num_channels_latents,
163
+ video_length,
164
+ height // self.vae_scale_factor,
165
+ width // self.vae_scale_factor,
166
+ )
167
+ if isinstance(generator, list) and len(generator) != batch_size:
168
+ raise ValueError(
169
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
170
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
171
+ )
172
+
173
+ if latents is None:
174
+ latents = randn_tensor(
175
+ shape, generator=generator, device=device, dtype=dtype
176
+ )
177
+ else:
178
+ latents = latents.to(device)
179
+
180
+ # scale the initial noise by the standard deviation required by the scheduler
181
+ latents = latents * self.scheduler.init_noise_sigma
182
+ return latents
183
+
184
+ def _encode_prompt(
185
+ self,
186
+ prompt,
187
+ device,
188
+ num_videos_per_prompt,
189
+ do_classifier_free_guidance,
190
+ negative_prompt,
191
+ ):
192
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
193
+
194
+ text_inputs = self.tokenizer(
195
+ prompt,
196
+ padding="max_length",
197
+ max_length=self.tokenizer.model_max_length,
198
+ truncation=True,
199
+ return_tensors="pt",
200
+ )
201
+ text_input_ids = text_inputs.input_ids
202
+ untruncated_ids = self.tokenizer(
203
+ prompt, padding="longest", return_tensors="pt"
204
+ ).input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
207
+ text_input_ids, untruncated_ids
208
+ ):
209
+ removed_text = self.tokenizer.batch_decode(
210
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
211
+ )
212
+
213
+ if (
214
+ hasattr(self.text_encoder.config, "use_attention_mask")
215
+ and self.text_encoder.config.use_attention_mask
216
+ ):
217
+ attention_mask = text_inputs.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ text_embeddings = self.text_encoder(
222
+ text_input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ text_embeddings = text_embeddings[0]
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ bs_embed, seq_len, _ = text_embeddings.shape
229
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ text_embeddings = text_embeddings.view(
231
+ bs_embed * num_videos_per_prompt, seq_len, -1
232
+ )
233
+
234
+ # get unconditional embeddings for classifier free guidance
235
+ if do_classifier_free_guidance:
236
+ uncond_tokens: List[str]
237
+ if negative_prompt is None:
238
+ uncond_tokens = [""] * batch_size
239
+ elif type(prompt) is not type(negative_prompt):
240
+ raise TypeError(
241
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
242
+ f" {type(prompt)}."
243
+ )
244
+ elif isinstance(negative_prompt, str):
245
+ uncond_tokens = [negative_prompt]
246
+ elif batch_size != len(negative_prompt):
247
+ raise ValueError(
248
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
249
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
250
+ " the batch size of `prompt`."
251
+ )
252
+ else:
253
+ uncond_tokens = negative_prompt
254
+
255
+ max_length = text_input_ids.shape[-1]
256
+ uncond_input = self.tokenizer(
257
+ uncond_tokens,
258
+ padding="max_length",
259
+ max_length=max_length,
260
+ truncation=True,
261
+ return_tensors="pt",
262
+ )
263
+
264
+ if (
265
+ hasattr(self.text_encoder.config, "use_attention_mask")
266
+ and self.text_encoder.config.use_attention_mask
267
+ ):
268
+ attention_mask = uncond_input.attention_mask.to(device)
269
+ else:
270
+ attention_mask = None
271
+
272
+ uncond_embeddings = self.text_encoder(
273
+ uncond_input.input_ids.to(device),
274
+ attention_mask=attention_mask,
275
+ )
276
+ uncond_embeddings = uncond_embeddings[0]
277
+
278
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
279
+ seq_len = uncond_embeddings.shape[1]
280
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
281
+ uncond_embeddings = uncond_embeddings.view(
282
+ batch_size * num_videos_per_prompt, seq_len, -1
283
+ )
284
+
285
+ # For classifier free guidance, we need to do two forward passes.
286
+ # Here we concatenate the unconditional and text embeddings into a single batch
287
+ # to avoid doing two forward passes
288
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
289
+
290
+ return text_embeddings
291
+
292
+ def interpolate_latents(
293
+ self, latents: torch.Tensor, interpolation_factor: int, device
294
+ ):
295
+ if interpolation_factor < 2:
296
+ return latents
297
+
298
+ new_latents = torch.zeros(
299
+ (
300
+ latents.shape[0],
301
+ latents.shape[1],
302
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
303
+ latents.shape[3],
304
+ latents.shape[4],
305
+ ),
306
+ device=latents.device,
307
+ dtype=latents.dtype,
308
+ )
309
+
310
+ org_video_length = latents.shape[2]
311
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
312
+
313
+ new_index = 0
314
+
315
+ v0 = None
316
+ v1 = None
317
+
318
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
319
+ v0 = latents[:, :, i0, :, :]
320
+ v1 = latents[:, :, i1, :, :]
321
+
322
+ new_latents[:, :, new_index, :, :] = v0
323
+ new_index += 1
324
+
325
+ for f in rate:
326
+ v = get_tensor_interpolation_method()(
327
+ v0.to(device=device), v1.to(device=device), f
328
+ )
329
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
330
+ new_index += 1
331
+
332
+ new_latents[:, :, new_index, :, :] = v1
333
+ new_index += 1
334
+
335
+ return new_latents
336
+
337
+ @torch.no_grad()
338
+ def __call__(
339
+ self,
340
+ ref_image,
341
+ pose_images,
342
+ width,
343
+ height,
344
+ video_length,
345
+ num_inference_steps,
346
+ guidance_scale,
347
+ num_images_per_prompt=1,
348
+ eta: float = 0.0,
349
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
350
+ output_type: Optional[str] = "tensor",
351
+ return_dict: bool = True,
352
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
353
+ callback_steps: Optional[int] = 1,
354
+ context_schedule="uniform",
355
+ context_frames=24,
356
+ context_stride=1,
357
+ context_overlap=4,
358
+ context_batch_size=1,
359
+ interpolation_factor=1,
360
+ **kwargs,
361
+ ):
362
+ # Default height and width to unet
363
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
364
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
365
+
366
+ device = self._execution_device
367
+
368
+ do_classifier_free_guidance = guidance_scale > 1.0
369
+
370
+ # Prepare timesteps
371
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
372
+ timesteps = self.scheduler.timesteps
373
+
374
+ batch_size = 1
375
+
376
+ # Prepare clip image embeds
377
+ clip_image = self.clip_image_processor.preprocess(
378
+ ref_image.resize((224, 224)), return_tensors="pt"
379
+ ).pixel_values
380
+ clip_image_embeds = self.image_encoder(
381
+ clip_image.to(device, dtype=self.image_encoder.dtype)
382
+ ).image_embeds
383
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
384
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
385
+
386
+ if do_classifier_free_guidance:
387
+ encoder_hidden_states = torch.cat(
388
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
389
+ )
390
+
391
+ reference_control_writer = ReferenceAttentionControl(
392
+ self.reference_unet,
393
+ do_classifier_free_guidance=do_classifier_free_guidance,
394
+ mode="write",
395
+ batch_size=batch_size,
396
+ fusion_blocks="full",
397
+ )
398
+ reference_control_reader = ReferenceAttentionControl(
399
+ self.denoising_unet,
400
+ do_classifier_free_guidance=do_classifier_free_guidance,
401
+ mode="read",
402
+ batch_size=batch_size,
403
+ fusion_blocks="full",
404
+ )
405
+
406
+ num_channels_latents = self.denoising_unet.in_channels
407
+ latents = self.prepare_latents(
408
+ batch_size * num_images_per_prompt,
409
+ num_channels_latents,
410
+ width,
411
+ height,
412
+ video_length,
413
+ clip_image_embeds.dtype,
414
+ device,
415
+ generator,
416
+ )
417
+
418
+ # Prepare extra step kwargs.
419
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
420
+
421
+ # Prepare ref image latents
422
+ ref_image_tensor = self.ref_image_processor.preprocess(
423
+ ref_image, height=height, width=width
424
+ ) # (bs, c, h, w)
425
+ ref_image_tensor = ref_image_tensor.to(
426
+ dtype=self.vae.dtype, device=self.vae.device
427
+ )
428
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
429
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
430
+
431
+ # Prepare a list of pose condition images
432
+ pose_cond_tensor_list = []
433
+ for pose_image in pose_images:
434
+ pose_cond_tensor = self.cond_image_processor.preprocess(
435
+ pose_image, height=height, width=width
436
+ )
437
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
438
+ pose_cond_tensor_list.append(pose_cond_tensor)
439
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)
440
+ pose_cond_tensor = pose_cond_tensor.to(
441
+ device=device, dtype=self.pose_guider.dtype
442
+ )
443
+ pose_fea = self.pose_guider(pose_cond_tensor)
444
+
445
+ context_scheduler = get_context_scheduler(context_schedule)
446
+
447
+ # denoising loop
448
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
449
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
450
+ for i, t in enumerate(timesteps):
451
+ noise_pred = torch.zeros(
452
+ (
453
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
454
+ *latents.shape[1:],
455
+ ),
456
+ device=latents.device,
457
+ dtype=latents.dtype,
458
+ )
459
+ counter = torch.zeros(
460
+ (1, 1, latents.shape[2], 1, 1),
461
+ device=latents.device,
462
+ dtype=latents.dtype,
463
+ )
464
+
465
+ # 1. Forward reference image
466
+ if i == 0:
467
+ self.reference_unet(
468
+ ref_image_latents.repeat(
469
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
470
+ ),
471
+ torch.zeros_like(t),
472
+ # t,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ return_dict=False,
475
+ )
476
+ reference_control_reader.update(reference_control_writer)
477
+
478
490
+ context_queue = list(
491
+ context_scheduler(
492
+ 0,
493
+ num_inference_steps,
494
+ latents.shape[2],
495
+ context_frames,
496
+ context_stride,
497
+ context_overlap,
498
+ )
499
+ )
500
+
501
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
502
+ global_context = []
503
+ for i in range(num_context_batches):
504
+ global_context.append(
505
+ context_queue[
506
+ i * context_batch_size : (i + 1) * context_batch_size
507
+ ]
508
+ )
509
+
510
+ for context in global_context:
511
+ # 3.1 expand the latents if we are doing classifier free guidance
512
+ latent_model_input = (
513
+ torch.cat([latents[:, :, c] for c in context])
514
+ .to(device)
515
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
516
+ )
517
+ latent_model_input = self.scheduler.scale_model_input(
518
+ latent_model_input, t
519
+ )
520
+ b, c, f, h, w = latent_model_input.shape
521
+ latent_pose_input = torch.cat(
522
+ [pose_fea[:, :, c] for c in context]
523
+ ).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
524
+
525
+ pred = self.denoising_unet(
526
+ latent_model_input,
527
+ t,
528
+ encoder_hidden_states=encoder_hidden_states[:b],
529
+ pose_cond_fea=latent_pose_input,
530
+ return_dict=False,
531
+ )[0]
532
+
533
+ for j, c in enumerate(context):
534
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
535
+ counter[:, :, c] = counter[:, :, c] + 1
536
+
537
+ # perform guidance
538
+ if do_classifier_free_guidance:
539
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
540
+ noise_pred = noise_pred_uncond + guidance_scale * (
541
+ noise_pred_text - noise_pred_uncond
542
+ )
543
+
544
+ latents = self.scheduler.step(
545
+ noise_pred, t, latents, **extra_step_kwargs
546
+ ).prev_sample
547
+
548
+ if i == len(timesteps) - 1 or (
549
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
550
+ ):
551
+ progress_bar.update()
552
+ if callback is not None and i % callback_steps == 0:
553
+ step_idx = i // getattr(self.scheduler, "order", 1)
554
+ callback(step_idx, t, latents)
555
+
556
+ reference_control_reader.clear()
557
+ reference_control_writer.clear()
558
+
559
+ if interpolation_factor > 0:
560
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
561
+ # Post-processing
562
+ images = self.decode_latents(latents) # (b, c, f, h, w)
563
+
564
+ # Convert to tensor
565
+ if output_type == "tensor":
566
+ images = torch.from_numpy(images)
567
+
568
+ if not return_dict:
569
+ return images
570
+
571
+ return Pose2VideoPipelineOutput(videos=images)
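The denoising loop above is what makes arbitrarily long clips possible: at every step the UNet only sees fixed-size windows of frames produced by the context scheduler, the window predictions are summed into `noise_pred`, and dividing by `counter` averages the frames that fall into more than one window, which smooths the seams between neighboring windows. Below is a self-contained sketch of that accumulate-and-average pattern; the window layout, tensor shapes, and random stand-in predictions are illustrative only, not the pipeline's defaults.

import torch

num_frames, context_frames, context_overlap = 24, 16, 4   # illustrative values
stride = context_frames - context_overlap
windows = [
    list(range(s, min(s + context_frames, num_frames)))
    for s in range(0, num_frames - context_overlap, stride)
]  # -> [[0, 1, ..., 15], [12, 13, ..., 23]]

latents = torch.randn(1, 4, num_frames, 8, 8)    # (b, c, f, h, w)
noise_pred = torch.zeros_like(latents)           # per-frame accumulator
counter = torch.zeros(1, 1, num_frames, 1, 1)    # times each frame was predicted

for c in windows:
    pred = torch.randn(1, 4, len(c), 8, 8)       # stand-in for the denoising UNet output
    noise_pred[:, :, c] += pred
    counter[:, :, c] += 1

noise_pred = noise_pred / counter                # overlapping frames are averaged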
src/pipelines/utils.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+
+ tensor_interpolation = None
+
+
+ def get_tensor_interpolation_method():
+     return tensor_interpolation
+
+
+ def set_tensor_interpolation_method(is_slerp):
+     global tensor_interpolation
+     tensor_interpolation = slerp if is_slerp else linear
+
+
+ def linear(v1, v2, t):
+     return (1.0 - t) * v1 + t * v2
+
+
+ def slerp(
+     v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
+ ) -> torch.Tensor:
+     u0 = v0 / v0.norm()
+     u1 = v1 / v1.norm()
+     dot = (u0 * u1).sum()
+     if dot.abs() > DOT_THRESHOLD:
+         # v0 and v1 are close to parallel: fall back to linear interpolation.
+         return (1.0 - t) * v0 + t * v1
+     omega = dot.acos()
+     return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
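These helpers pick how neighboring latent frames are blended, presumably backing the `interpolate_latents` call used when `interpolation_factor > 0` in the pipeline above: `set_tensor_interpolation_method(True)` selects spherical interpolation, `False` selects plain linear blending. A small usage sketch (tensors are illustrative; assumes the repository root is on `PYTHONPATH`):

import torch

from src.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

a, b = torch.randn(4), torch.randn(4)

set_tensor_interpolation_method(is_slerp=True)   # True -> slerp, False -> linear
interp = get_tensor_interpolation_method()
midpoint = interp(a, b, 0.5)                     # blend halfway between a and b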
src/utils/util.py ADDED
@@ -0,0 +1,111 @@
+ import importlib
+ import os
+ import os.path as osp
+ import sys
+ from pathlib import Path
+
+ import av
+ import numpy as np
+ import torch
+ import torchvision
+ from einops import rearrange
+ from PIL import Image
+
+
+ def seed_everything(seed):
+     """Seed torch, numpy and random for reproducible runs."""
+     import random
+
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed % (2**32))
+     random.seed(seed)
+
+
+ def import_filename(filename):
+     """Import a Python file by path and return the loaded module."""
+     spec = importlib.util.spec_from_file_location("mymodule", filename)
+     module = importlib.util.module_from_spec(spec)
+     sys.modules[spec.name] = module
+     spec.loader.exec_module(module)
+     return module
+
+
+ def save_videos_from_pil(pil_images, path, fps=8):
+     """Write a list of PIL images to an .mp4 (H.264) or .gif file."""
+     save_fmt = Path(path).suffix
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     width, height = pil_images[0].size
+
+     if save_fmt == ".mp4":
+         codec = "libx264"
+         container = av.open(path, "w")
+         stream = container.add_stream(codec, rate=fps)
+         stream.width = width
+         stream.height = height
+
+         for pil_image in pil_images:
+             av_frame = av.VideoFrame.from_image(pil_image)
+             container.mux(stream.encode(av_frame))
+         container.mux(stream.encode())  # flush the encoder
+         container.close()
+
+     elif save_fmt == ".gif":
+         pil_images[0].save(
+             fp=path,
+             format="GIF",
+             append_images=pil_images[1:],
+             save_all=True,
+             duration=(1 / fps * 1000),
+             loop=0,
+         )
+     else:
+         raise ValueError("Unsupported file type. Use .mp4 or .gif.")
+
+
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+     """Tile a (b, c, t, h, w) video batch into a grid and save it as a video."""
+     videos = rearrange(videos, "b c t h w -> t b c h w")
+     height, width = videos.shape[-2:]
+     outputs = []
+
+     for x in videos:
+         x = torchvision.utils.make_grid(x, nrow=n_rows)  # (c h w)
+         x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (h w c)
+         if rescale:
+             x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+         x = (x * 255).numpy().astype(np.uint8)
+         x = Image.fromarray(x)
+         outputs.append(x)
+
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     save_videos_from_pil(outputs, path, fps)
+
+
+ def read_frames(video_path):
+     """Decode every frame of a video into a list of RGB PIL images."""
+     container = av.open(video_path)
+     video_stream = next(s for s in container.streams if s.type == "video")
+     frames = []
+     for packet in container.demux(video_stream):
+         for frame in packet.decode():
+             image = Image.frombytes(
+                 "RGB",
+                 (frame.width, frame.height),
+                 frame.to_rgb().to_ndarray(),
+             )
+             frames.append(image)
+     container.close()
+     return frames
+
+
+ def get_fps(video_path):
+     """Return the average frame rate of a video."""
+     container = av.open(video_path)
+     video_stream = next(s for s in container.streams if s.type == "video")
+     fps = video_stream.average_rate
+     container.close()
+     return fps
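A short round-trip sketch of the I/O helpers above: decode a clip to PIL frames, re-encode it at the same frame rate, and save a tiled grid from a tensor. Paths and tensor values are placeholders.

import torch

from src.utils.util import get_fps, read_frames, save_videos_from_pil, save_videos_grid

frames = read_frames("path/to/input.mp4")        # list of RGB PIL images
fps = get_fps("path/to/input.mp4")               # Fraction, accepted by PyAV
save_videos_from_pil(frames, "output/input_copy.mp4", fps=fps)

video = torch.rand(2, 3, 8, 64, 64)              # (b, c, t, h, w) in [0, 1]
save_videos_grid(video, "output/grid.gif", n_rows=2, fps=8)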
tools/vid2pose.py ADDED
@@ -0,0 +1,38 @@
+ import os
+
+ import numpy as np
+
+ from src.dwpose import DWposeDetector
+ from src.utils.util import get_fps, read_frames, save_videos_from_pil
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--video_path", type=str)
+     args = parser.parse_args()
+
+     if not os.path.exists(args.video_path):
+         raise ValueError(f"Path {args.video_path} does not exist")
+
+     dir_path, video_name = (
+         os.path.dirname(args.video_path),
+         os.path.splitext(os.path.basename(args.video_path))[0],
+     )
+     out_path = os.path.join(dir_path, video_name + "_kps.mp4")
+
+     detector = DWposeDetector()
+     detector = detector.to("cuda")
+
+     fps = get_fps(args.video_path)
+     frames = read_frames(args.video_path)
+     kps_results = []
+     for frame_pil in frames:
+         result, score = detector(frame_pil)
+         score = np.mean(score, axis=-1)  # mean keypoint confidence (currently unused)
+
+         kps_results.append(result)
+
+     print(out_path)
+     save_videos_from_pil(kps_results, out_path, fps=fps)
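Run from the repository root, the script writes the rendered keypoint video next to the input with a `_kps.mp4` suffix, matching the naming of the bundled pose videos under configs/inference/pose_videos. A programmatic sketch of the same flow (input path is a placeholder; assumes the DWPose ONNX weights are already in place):

from src.dwpose import DWposeDetector
from src.utils.util import get_fps, read_frames, save_videos_from_pil

detector = DWposeDetector().to("cuda")
frames = read_frames("path/to/dance.mp4")
kps_frames = [detector(frame)[0] for frame in frames]   # keep the drawn-skeleton frame
save_videos_from_pil(kps_frames, "path/to/dance_kps.mp4", fps=get_fps("path/to/dance.mp4"))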