multimodalart committed
Commit b4b0028 · verified · 1 Parent(s): 3979845

faster safety checking

Files changed (1)
  1. app.py +30 -1
app.py CHANGED
@@ -1,16 +1,45 @@
+import subprocess
+
+subprocess.run(
+    "pip install flash-attn --no-build-isolation", env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, shell=True
+)
+
 import gradio as gr
 import spaces
 import torch
 from diffusers import Cosmos2TextToImagePipeline, EDMEulerScheduler
+from transformers import AutoModelForCausalLM, SiglipProcessor
 import random
 
+# Add flash_attention_2 to the safeguard model
+def patch_from_pretrained(cls):
+    orig_method = cls.from_pretrained
+
+    def new_from_pretrained(*args, **kwargs):
+        kwargs.setdefault("attn_implementation", "flash_attention_2")
+        kwargs.setdefault("torch_dtype", torch.bfloat16)
+        return orig_method(*args, **kwargs)
+
+    cls.from_pretrained = new_from_pretrained
+
+patch_from_pretrained(AutoModelForCausalLM)
+
+# Add a `use_fast` to the safeguard image processor
+def patch_processor_fast(cls):
+    orig_method = cls.from_pretrained
+    def new_from_pretrained(*args, **kwargs):
+        kwargs.setdefault("use_fast", True)
+        return orig_method(*args, **kwargs)
+    cls.from_pretrained = new_from_pretrained
+
+patch_processor_fast(SiglipProcessor)
+
 model_14b_id = "nvidia/Cosmos-Predict2-14B-Text2Image"
 
 pipe_14b = Cosmos2TextToImagePipeline.from_pretrained(
     model_14b_id,
     torch_dtype=torch.bfloat16
 )
-
 pipe_14b.to("cuda")
 
 @spaces.GPU(duration=140)
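
For reference, a minimal sketch (not part of the commit) of why the patches use kwargs.setdefault(): the wrapped from_pretrained only fills in attn_implementation, torch_dtype, and use_fast when the caller omits them, so the safeguard models the pipeline loads internally pick up the faster defaults, while any caller that passes these kwargs explicitly keeps its own values.

# Standalone sketch of the setdefault behaviour; `load` is a hypothetical
# stand-in for the patched from_pretrained, not part of the commit.
def load(**kwargs):
    kwargs.setdefault("attn_implementation", "flash_attention_2")
    kwargs.setdefault("use_fast", True)
    return kwargs

print(load())                            # both defaults are filled in
print(load(attn_implementation="sdpa"))  # an explicit choice is preserved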