yeshog50 committed on
Commit
fad0145
·
verified ·
1 Parent(s): f121897

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -19
app.py CHANGED
@@ -2,24 +2,16 @@ import os
2
  import random
3
  import gradio as gr
4
  import torch
5
- from diffusers import DiffusionPipeline, UNet2DConditionModel
6
  from transformers import CLIPTextModel, CLIPTokenizer
7
 
8
  # Configuration - Using Flux Model
9
  MODEL_ID = "CompVis/Flux-Pro"
10
- LORA_ID = "flux/lora-weights"
11
  MODEL_CACHE = "model_cache"
12
  os.makedirs(MODEL_CACHE, exist_ok=True)
13
 
14
  def get_pipeline():
15
- # Load Flux components
16
- unet = UNet2DConditionModel.from_pretrained(
17
- MODEL_ID,
18
- subfolder="unet",
19
- cache_dir=MODEL_CACHE,
20
- torch_dtype=torch.float32
21
- )
22
-
23
  text_encoder = CLIPTextModel.from_pretrained(
24
  MODEL_ID,
25
  subfolder="text_encoder",
@@ -35,7 +27,6 @@ def get_pipeline():
35
  # Create pipeline
36
  pipe = DiffusionPipeline.from_pretrained(
37
  MODEL_ID,
38
- unet=unet,
39
  text_encoder=text_encoder,
40
  tokenizer=tokenizer,
41
  cache_dir=MODEL_CACHE,
@@ -43,14 +34,6 @@ def get_pipeline():
43
  safety_checker=None
44
  )
45
 
46
- # Load LoRA weights
47
- lora_path = hf_hub_download(
48
- LORA_ID,
49
- "flux_lora.safetensors",
50
- cache_dir=MODEL_CACHE
51
- )
52
- pipe.unet.load_attn_procs(lora_path)
53
-
54
  # CPU optimizations
55
  pipe = pipe.to("cpu")
56
  pipe.enable_attention_slicing()
 
2
  import random
3
  import gradio as gr
4
  import torch
5
+ from diffusers import DiffusionPipeline
6
  from transformers import CLIPTextModel, CLIPTokenizer
7
 
8
  # Configuration - Using Flux Model
9
  MODEL_ID = "CompVis/Flux-Pro"
 
10
  MODEL_CACHE = "model_cache"
11
  os.makedirs(MODEL_CACHE, exist_ok=True)
12
 
13
  def get_pipeline():
14
+ # Load Flux model components
 
 
 
 
 
 
 
15
  text_encoder = CLIPTextModel.from_pretrained(
16
  MODEL_ID,
17
  subfolder="text_encoder",
 
27
  # Create pipeline
28
  pipe = DiffusionPipeline.from_pretrained(
29
  MODEL_ID,
 
30
  text_encoder=text_encoder,
31
  tokenizer=tokenizer,
32
  cache_dir=MODEL_CACHE,
 
34
  safety_checker=None
35
  )
36
 
 
 
 
 
 
 
 
 
37
  # CPU optimizations
38
  pipe = pipe.to("cpu")
39
  pipe.enable_attention_slicing()