Bton committed on
Commit
cb588cb
·
verified ·
1 Parent(s): a938a5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -126
app.py CHANGED
@@ -1,154 +1,181 @@
1
  import gradio as gr
2
- import numpy as np
 
3
  import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
 
 
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
 
 
 
 
 
 
 
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
  prompt=prompt,
43
- negative_prompt=negative_prompt,
44
  guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
  width=width,
47
  height=height,
48
  generator=generator,
49
  ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
 
 
72
  prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
  )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
118
 
 
 
 
 
 
 
 
 
 
119
  with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
 
 
 
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
  ],
150
- outputs=[result, seed],
 
151
  )
152
 
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
+ import os
4
  import random
5
+ import uuid
6
+ from datetime import datetime
7
  from diffusers import DiffusionPipeline
8
  import torch
9
+ import numpy as np
10
+ from PIL import Image
11
 
12
# Number of denoising steps passed to the pipeline on every generation.
NUM_INFERENCE_STEPS = 8

# Hugging Face access token read from the environment; None when unset.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Constants
# Upper bound for seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Create permanent storage directory for Flux generated images.
# exist_ok=True already makes this call idempotent, so a separate
# os.path.exists() pre-check is unnecessary (and would only add a window in
# which another process could create the directory between check and call).
SAVE_DIR = "saved_images"
os.makedirs(SAVE_DIR, exist_ok=True)
22
+
23
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for one generation.

    Args:
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.
        seed: Seed returned unchanged when ``randomize_seed`` is false.

    Returns:
        The supplied ``seed``, or a random integer in ``[0, MAX_SEED]``.
    """
    if not randomize_seed:
        return seed
    # random.randint includes both endpoints, so MAX_SEED itself can be drawn;
    # this matches the seed slider's range and the sampling already used in
    # generate_flux_image (np.random.randint would exclude the upper bound).
    return random.randint(0, MAX_SEED)
25
+
26
@spaces.GPU
def generate_flux_image(
    prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> tuple[Image.Image, int, str]:
    """Generate an image with the Flux pipeline and save it to SAVE_DIR.

    Args:
        prompt: User description; it is wrapped with a fixed prefix and
            suffix before being sent to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, a fresh random seed replaces ``seed``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        progress: Gradio progress tracker (tracks tqdm output).

    Returns:
        A tuple ``(image, seed, filepath)``: the generated image, the seed
        actually used, and the path of the saved PNG.
    """
    # Reuse the shared helper instead of duplicating the seed logic, and
    # coerce to int because torch.Generator.manual_seed requires an integer.
    seed = int(get_seed(randomize_seed, seed))
    generator = torch.Generator(device=device).manual_seed(seed)

    # NOTE(review): "wbgmsst" looks like a LoRA trigger word — confirm.
    prompt = "wbgmsst, " + prompt + ", 3D isometric, white background"
    image = flux_pipeline(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=NUM_INFERENCE_STEPS,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Save the generated image under a timestamped, collision-resistant name.
    # Re-create the directory if it was removed after start-up.
    os.makedirs(SAVE_DIR, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    filename = f"{timestamp}_{unique_id}.png"
    filepath = os.path.join(SAVE_DIR, filename)
    image.save(filepath)

    return image, seed, filepath
58
+
59
+ # Gradio Interface
60
with gr.Blocks() as demo:
    # Header: short usage notes plus links to the models involved.
    gr.Markdown("""
    ## Game Asset Generation with FLUX
    * Enter a prompt to generate a game asset image
    * Images are automatically saved to the 'saved_images' directory
    * [Flux-Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
    * [Flux Game Assets LoRA](https://huggingface.co/gokaygokay/Flux-Game-Assets-LoRA-v2)
    * [Hyper FLUX 8Steps LoRA](https://huggingface.co/ByteDance/Hyper-SD)
    """)

    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column():
            prompt = gr.Text(
                lines=3,
                label="Prompt",
                placeholder="Enter your game asset description",
            )

            with gr.Accordion("Generation Settings", open=True):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")

                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=512,
                        maximum=1024,
                        step=16,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=512,
                        maximum=1024,
                        step=16,
                        value=1024,
                    )

                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=3.5,
                )

            generate_btn = gr.Button("Generate", variant="primary")

        # Right column: result image, seed/path read-outs and download button.
        with gr.Column():
            generated_image = gr.Image(
                type="pil",
                label="Generated Asset",
                interactive=False,
            )

            with gr.Row():
                seed_output = gr.Number(interactive=False, label="Seed Used")
                file_path = gr.Text(interactive=False, label="Saved To")

            # Hidden until a generation finishes and a file exists to serve.
            download_btn = gr.DownloadButton(
                visible=False,
                label="Download Image",
            )

    # Event handlers: generate first, then reveal the download button pointing
    # at the path that the generation step wrote into `file_path`.
    generation_event = generate_btn.click(
        generate_flux_image,
        inputs=[prompt, seed, randomize_seed, width, height, guidance_scale],
        outputs=[generated_image, seed_output, file_path],
    )
    generation_event.then(
        lambda saved_path: gr.DownloadButton(visible=True, value=saved_path),
        inputs=[file_path],
        outputs=[download_btn],
    )

    # Examples: each entry is a one-element row feeding the prompt box.
    gr.Examples(
        examples=[
            [text]
            for text in (
                "medieval sword with glowing runes",
                "wooden treasure chest with gold coins",
                "health potion bottle with red liquid",
                "stone castle tower",
                "pixel art coin sprite",
                "low poly tree",
                "fantasy spell book",
                "iron shield with dragon emblem",
            )
        ],
        inputs=prompt,
        label="Example Prompts",
    )
157
 
158
+ # Initialize Flux pipeline
159
if __name__ == "__main__":
    # Heavy model imports are kept local to the script entry point.
    from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig, GGUFQuantizationConfig
    from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF

    # Initialize Flux pipeline
    # `device` and `flux_pipeline` are module globals read by
    # generate_flux_image at call time.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

    dtype = torch.bfloat16
    file_url = "https://huggingface.co/gokaygokay/flux-game/blob/main/hyperflux_00001_.q8_0.gguf"
    # Normalize download-style links ("/resolve/main/", "?download=true") to
    # the "/blob/main/" form; a no-op for the URL above.
    file_url = file_url.replace("/resolve/main/", "/blob/main/").replace("?download=true", "")
    single_file_base_model = "camenduru/FLUX.1-dev-diffusers"

    # T5 text encoder loaded in 8-bit via bitsandbytes.
    quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        single_file_base_model,
        subfolder="text_encoder_2",
        torch_dtype=dtype,
        config=single_file_base_model,
        quantization_config=quantization_config_tf,
        token=huggingface_token,
    )

    if ".gguf" in file_url:
        # GGUF checkpoint: use the GGUF quantization config.
        transformer = FluxTransformer2DModel.from_single_file(
            file_url,
            subfolder="transformer",
            quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
            torch_dtype=dtype,
            config=single_file_base_model,
        )
    else:
        # Other checkpoint formats: quantize to 4-bit NF4 at load time.
        # The access token belongs to the loader call, not to the
        # quantization config, so it is passed only to from_single_file.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        transformer = FluxTransformer2DModel.from_single_file(
            file_url,
            subfolder="transformer",
            torch_dtype=dtype,
            config=single_file_base_model,
            quantization_config=quantization_config,
            token=huggingface_token,
        )

    flux_pipeline = FluxPipeline.from_pretrained(
        single_file_base_model,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
        token=huggingface_token,
    )
    # Use the detected device (the same one the generator is created on)
    # instead of a hard-coded "cuda".
    flux_pipeline.to(device)

    demo.launch()