rahul7star committed on
Commit
0de2bab
·
verified ·
1 Parent(s): 3e06c3c

Update app_lora.py

Browse files
Files changed (1) hide show
  1. app_lora.py +22 -17
app_lora.py CHANGED
@@ -36,17 +36,8 @@ DEFAULT_NAG_NEGATIVE_PROMPT = "Static, motionless, still, ugly, bad quality, wor
36
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
37
  SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
38
  SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
39
-
40
-
41
-
42
- #LORA_REPO_ID = "Kijai/WanVideo_comfy"
43
- #LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
44
-
45
-
46
- # new experiment for future work
47
-
48
- LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
49
- LORA_FILENAME = "FusionX_LoRa/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
50
 
51
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
52
  wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
@@ -57,9 +48,15 @@ pipe = NAGWanPipeline.from_pretrained(
57
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
58
  pipe.to("cuda")
59
 
60
-
61
-
62
-
 
 
 
 
 
 
63
 
64
  pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
65
  pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
@@ -134,6 +131,10 @@ def generate_video(
134
  else:
135
  baseline_video_path = None
136
 
 
 
 
 
137
  return nag_video_path, baseline_video_path, current_seed
138
 
139
 
@@ -150,15 +151,19 @@ def generate_video_with_example(
150
  seed=DEFAULT_SEED, randomize_seed=False,
151
  compare=True,
152
  )
 
 
 
153
  return nag_video_path, baseline_video_path, \
154
  DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, \
155
  DEFAULT_DURATION_SECONDS, DEFAULT_STEPS, seed, True
156
 
157
 
158
  with gr.Blocks() as demo:
159
- gr.Markdown('''# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with FuxionX Base T2V
160
  Implementation of [Normalized Attention Guidance](https://chendaryen.github.io/NAG.github.io/).
161
-
 
162
  ''')
163
 
164
  with gr.Row():
@@ -229,4 +234,4 @@ with gr.Blocks() as demo:
229
  )
230
 
231
  if __name__ == "__main__":
232
- demo.queue().launch()
 
36
  MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
37
  SUB_MODEL_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
38
  SUB_MODEL_FILENAME = "Wan14BT2VFusioniX_fp16_.safetensors"
39
+ LORA_REPO_ID = "Kijai/WanVideo_comfy"
40
+ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 
 
 
 
 
 
 
 
 
41
 
42
  vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
43
  wan_path = hf_hub_download(repo_id=SUB_MODEL_ID, filename=SUB_MODEL_FILENAME)
 
48
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
49
  pipe.to("cuda")
50
 
51
+ causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
52
+ pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
53
+ pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
54
+ for name, param in pipe.transformer.named_parameters():
55
+ if "lora_B" in name:
56
+ if "blocks.0" in name:
57
+ param.data = param.data * 0.25
58
+ pipe.fuse_lora()
59
+ pipe.unload_lora_weights()
60
 
61
  pipe.transformer.__class__.attn_processors = NagWanTransformer3DModel.attn_processors
62
  pipe.transformer.__class__.set_attn_processor = NagWanTransformer3DModel.set_attn_processor
 
131
  else:
132
  baseline_video_path = None
133
 
134
+ if torch.cuda.is_available():
135
+ print("Allocated:", torch.cuda.memory_allocated() / 1024**2, "MB")
136
+ print("Cached: ", torch.cuda.memory_reserved() / 1024**2, "MB")
137
+
138
  return nag_video_path, baseline_video_path, current_seed
139
 
140
 
 
151
  seed=DEFAULT_SEED, randomize_seed=False,
152
  compare=True,
153
  )
154
+ if torch.cuda.is_available():
155
+ print("Allocated:", torch.cuda.memory_allocated() / 1024**2, "MB")
156
+ print("Cached: ", torch.cuda.memory_reserved() / 1024**2, "MB")
157
  return nag_video_path, baseline_video_path, \
158
  DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE, \
159
  DEFAULT_DURATION_SECONDS, DEFAULT_STEPS, seed, True
160
 
161
 
162
  with gr.Blocks() as demo:
163
+ gr.Markdown('''# Normalized Attention Guidance (NAG) for fast 4 steps Wan2.1-T2V-14B with CausVid LoRA
164
  Implementation of [Normalized Attention Guidance](https://chendaryen.github.io/NAG.github.io/).
165
+
166
+ [CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan2.1 to run faster in just 4-8 steps, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors).
167
  ''')
168
 
169
  with gr.Row():
 
234
  )
235
 
236
  if __name__ == "__main__":
237
+ demo.queue().launch()