lunarflu HF Staff committed on
Commit
427bebd
·
1 Parent(s): 3391b6e

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. deepfloydif.py +298 -0
deepfloydif.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import glob
3
+ import os
4
+ import pathlib
5
+ import random
6
+ import threading
7
+
8
+ import gradio as gr
9
+ import discord
10
+ from gradio_client import Client
11
+ from PIL import Image
12
+ from discord.ext import commands
13
+ from discord.ui import Button, View
14
+
15
# Credentials come from the environment; os.getenv returns None when a variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")

# gradio_client handle for the "huggingface-projects/IF" Space; every inference helper
# in this file goes through this single client.
deepfloydif_client = Client("huggingface-projects/IF", HF_TOKEN)

# The bot requests every gateway intent and uses "/" as its command prefix.
intents = discord.Intents.all()
bot = commands.Bot(command_prefix="/", intents=intents)
20
+
21
+
22
@bot.event
async def on_ready():
    """Report the login and sync the bot's command tree.

    Prints the bot's user and ID, syncs ``bot.tree``, then prints the names
    of the commands returned by the sync followed by a separator line.
    """
    print(f"Logged in as {bot.user} (ID: {bot.user.id})")
    synced_commands = await bot.tree.sync()
    command_names = ", ".join(command.name for command in synced_commands)
    print(f"Synced commands: {command_names}.")
    print("------")
28
+
29
+
30
@bot.hybrid_command(
    name="deepfloydif",
    description="Enter a prompt to generate an image! Can generate realistic text, too!",
)
async def deepfloydif(ctx, prompt: str):
    """DeepfloydIF stage 1 generation"""
    # Entry point for the `deepfloydif` command: hand the prompt to the
    # stage 1 workflow. Any exception that escapes it is printed instead of
    # being propagated to discord.py.
    try:
        await deepfloydif_generate64(ctx, prompt)
    except Exception as error:
        print(f"Error: {error}")
40
+
41
+
42
def deepfloydif_generate64_inference(prompt):
    """Call the Space's ``/generate64`` endpoint for ``prompt``.

    Requests four images with a random seed in [0, 1000], an empty negative
    prompt, guidance scale 7, "smart50" timesteps and 50 inference steps.

    Returns:
        A three-item list ``[stage_1_images, stage_1_param_path,
        path_for_upscale256_upscaling]`` holding the three values the
        endpoint returned, in that order.
    """
    stage_1_outputs = deepfloydif_client.predict(
        prompt,
        "",  # negative prompt
        random.randint(0, 1000),  # seed
        4,  # number of images
        7,  # guidance scale
        "smart50",  # custom timesteps
        50,  # number of inference steps
        api_name="/generate64",
    )
    # Unpacking enforces that the endpoint returned exactly three values.
    stage_1_images, stage_1_param_path, path_for_upscale256_upscaling = stage_1_outputs
    return [stage_1_images, stage_1_param_path, path_for_upscale256_upscaling]
65
+
66
+
67
def deepfloydif_upscale256_inference(index, path_for_upscale256_upscaling):
    """Call the Space's ``/upscale256`` endpoint for one stage 1 image.

    Args:
        index: Which of the images produced by
            ``deepfloydif_generate64_inference`` to upscale.
        path_for_upscale256_upscaling: The third value returned by
            ``deepfloydif_generate64_inference``.

    Returns:
        Whatever the endpoint returns; callers open it as a file path.
    """
    # Fixed stage 2 settings: seed, guidance scale, custom timesteps, inference steps.
    stage_2_settings = (0, 4, "smart50", 50)
    return deepfloydif_client.predict(
        path_for_upscale256_upscaling,
        index,
        *stage_2_settings,
        api_name="/upscale256",
    )
84
+
85
+
86
def deepfloydif_upscale1024_inference(index, path_for_upscale256_upscaling, prompt):
    """Call the Space's ``/upscale1024`` endpoint (stage 2, then stage 3).

    Args:
        index: Which stage 1 image to upscale.
        path_for_upscale256_upscaling: The third value returned by
            ``deepfloydif_generate64_inference``.
        prompt: The text prompt, forwarded to the endpoint.

    Returns:
        Whatever the endpoint returns; callers open it as a file path.
    """
    # Stage 2 (256 upscaling) defaults: seed, guidance scale, custom timesteps, inference steps.
    stage_2_settings = (0, 4, "smart50", 50)
    # Stage 3 (1024 upscaling) defaults: seed, guidance scale, inference steps.
    stage_3_settings = (0, 9, 40)
    # The negative prompt is always empty (not used, could add in the future).
    empty_negative_prompt = ""

    return deepfloydif_client.predict(
        path_for_upscale256_upscaling,
        index,
        *stage_2_settings,
        prompt,
        empty_negative_prompt,
        *stage_3_settings,
        api_name="/upscale1024",
    )
114
+
115
+
116
def load_image(png_files, stage_1_images):
    """Open each PNG so the images can be combined later.

    Args:
        png_files: Iterable of file names or paths; each one is joined onto
            ``stage_1_images`` with ``os.path.join`` before being opened.
        stage_1_images: Directory the PNG entries are resolved against.

    Returns:
        A list of opened PIL images, in the same order as ``png_files``.
    """
    return [Image.open(os.path.join(stage_1_images, png_name)) for png_name in png_files]
123
+
124
+
125
def combine_images(png_files, stage_1_images, partial_path):
    """Paste the first four images into a 2x2 grid and save it as a PNG.

    The grid is twice the width and height of the first image. Images 0-3
    go to the top-left, top-right, bottom-left and bottom-right quadrants.

    Args:
        png_files: PNG entries passed on to ``load_image``.
        stage_1_images: Directory used to load the PNGs and to save the grid.
        partial_path: File name (without extension) for the combined image.

    Returns:
        The path of the saved combined image.
    """
    if os.environ.get("TEST_ENV") == "True":
        print("Combining images for deepfloydif_generate64")

    tiles = load_image(png_files, stage_1_images)
    tile_width, tile_height = tiles[0].width, tiles[0].height
    grid = Image.new("RGB", (tile_width * 2, tile_height * 2))

    # Quadrant number -> (row, column): 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
    for quadrant in range(4):
        row, column = divmod(quadrant, 2)
        grid.paste(tiles[quadrant], (column * tile_width, row * tile_height))

    grid_path = os.path.join(stage_1_images, f"{partial_path}.png")
    grid.save(grid_path)
    return grid_path
137
+
138
+
139
async def deepfloydif_generate64(ctx, prompt):
    """Generate stage 1 images for ``prompt`` and post them with upscale buttons.

    Workflow:
      1. Send a "(generating...)" message and run
         ``deepfloydif_generate64_inference`` in the default executor.
      2. Combine the generated PNGs into one 2x2 image, enlarge it 3x per
         side, and post it to the channel with four arrow buttons.
      3. Each arrow button upscales the matching image via
         ``deepfloydif_upscale256`` and posts it with a
         "High-quality upscale (x4)" button, which in turn calls
         ``deepfloydif_upscale1024`` and posts that result.

    If an upscale helper returns ``None`` (it prints the error and returns
    nothing on failure), the user's ephemeral "(upscaling...)" message is
    replaced by a failure notice instead of crashing on ``open(None)``.

    Args:
        ctx: The command context; its channel and author are used for replies.
        prompt: The text prompt to generate images for.

    Any exception raised by the main workflow is printed, not propagated.
    """
    try:
        channel = ctx.channel
        # interaction.response message can't be used to create a thread, so we create another message
        message = await ctx.send(f"**{prompt}** - {ctx.author.mention} (generating...)")

        # The gradio_client call blocks, so it runs in the loop's default executor.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, deepfloydif_generate64_inference, prompt)
        stage_1_images = result[0]
        path_for_upscale256_upscaling = result[2]

        partial_path = pathlib.Path(path_for_upscale256_upscaling).name
        png_files = list(glob.glob(f"{stage_1_images}/**/*.png"))

        if not png_files:
            await ctx.send(f"{ctx.author.mention} No PNG files were found, cannot post them!")
            return

        await message.delete()
        combined_image_path = combine_images(png_files, stage_1_images, partial_path)
        if os.environ.get("TEST_ENV") == "True":
            print("Images combined for deepfloydif_generate64")

        # making image bigger, more readable (3x per side, saved over the combined image)
        with Image.open(combined_image_path) as img:
            width, height = img.size
            resized_img = img.resize((width * 3, height * 3))
            resized_img.save(combined_image_path)

        async def upscale1024_callback(interaction):
            """Run the stage 2 + stage 3 upscale for the clicked image and post it."""
            index = int(interaction.data["custom_id"])

            await interaction.response.send_message(
                f"{interaction.user.mention} (upscaling...)", ephemeral=True
            )
            result_path = await deepfloydif_upscale1024(index, path_for_upscale256_upscaling, prompt)
            if result_path is None:
                # The helper already printed the error; tell the user instead of failing on open().
                await interaction.edit_original_response(
                    content=f"{interaction.user.mention} Upscaling failed, please try again!"
                )
                return

            with open(result_path, "rb") as upscaled_file:
                await interaction.delete_original_response()
                await channel.send(
                    content=f"{interaction.user.mention} Here's your high-quality x16 image!",
                    file=discord.File(upscaled_file, f"{prompt}.png"),
                )

        async def button_callback(interaction):
            """Upscale the image chosen by an arrow button and offer the x4 upscale."""
            index = int(interaction.data["custom_id"])  # 0,1,2,3

            await interaction.response.send_message(
                f"{interaction.user.mention} (upscaling...)", ephemeral=True
            )
            result_path = await deepfloydif_upscale256(index, path_for_upscale256_upscaling)
            if result_path is None:
                # The helper already printed the error; tell the user instead of failing on open().
                await interaction.edit_original_response(
                    content=f"{interaction.user.mention} Upscaling failed, please try again!"
                )
                return

            # create and use upscale 1024 button
            upscale1024 = Button(label="High-quality upscale (x4)", custom_id=str(index))
            upscale1024.callback = upscale1024_callback
            upscale_view = View(timeout=None)
            upscale_view.add_item(upscale1024)

            with open(result_path, "rb") as upscaled_file:
                await interaction.delete_original_response()
                await channel.send(
                    content=(
                        f"{interaction.user.mention} Here is the upscaled image! Click the button"
                        " to upscale even more!"
                    ),
                    file=discord.File(upscaled_file, f"{prompt}.png"),
                    view=upscale_view,
                )

        # One button per quadrant; custom_id is the image index the callback reads back.
        view = View(timeout=None)
        for image_index, arrow in enumerate(("↖", "↗", "↙", "↘")):
            button = Button(custom_id=str(image_index), emoji=arrow)
            button.callback = button_callback
            view.add_item(button)

        # could store this message as combined_image_dfif in case it's useful for future testing
        with open(combined_image_path, "rb") as f:
            await channel.send(
                f"**{prompt}** - {ctx.author.mention} Click a button to upscale! (make larger + enhance quality)",
                file=discord.File(f, f"{partial_path}.png"),
                view=view,
            )

    except Exception as e:
        print(f"Error: {e}")
237
+
238
+
239
async def deepfloydif_upscale256(index: int, path_for_upscale256_upscaling):
    """Run ``deepfloydif_upscale256_inference`` off the event loop.

    The blocking inference call is handed to the running loop's default
    executor.

    Returns:
        The inference result, or ``None`` if the call raised (the error is
        printed rather than propagated).
    """
    try:
        return await asyncio.get_running_loop().run_in_executor(
            None, deepfloydif_upscale256_inference, index, path_for_upscale256_upscaling
        )
    except Exception as error:
        print(f"Error: {error}")
        return None
250
+
251
+
252
async def deepfloydif_upscale1024(index: int, path_for_upscale256_upscaling, prompt):
    """Run ``deepfloydif_upscale1024_inference`` off the event loop.

    The blocking inference call is handed to the running loop's default
    executor.

    Returns:
        The inference result, or ``None`` if the call raised (the error is
        printed rather than propagated).
    """
    try:
        return await asyncio.get_running_loop().run_in_executor(
            None, deepfloydif_upscale1024_inference, index, path_for_upscale256_upscaling, prompt
        )
    except Exception as error:
        print(f"Error: {error}")
        return None
263
+
264
+
265
def run_bot():
    """Start the Discord bot with DISCORD_TOKEN; used as the bot thread's target."""
    bot.run(DISCORD_TOKEN)


# The bot runs in its own thread so the module can go on to build and launch
# the Gradio app below.
threading.Thread(target=run_bot).start()
270
+
271
+
272
# Markdown shown on the Gradio page: invite link plus usage notes for the bot.
welcome_message = """
## Add this bot to your server by clicking this link:

https://discord.com/api/oauth2/authorize?client_id=1154395078735953930&permissions=51200&scope=bot

## How to use it?

The bot can be triggered via `/deepfloydif` followed by your text prompt.

This will generate images based on the text prompt. You can upscale the images using the buttons up to 16x!

⚠️ Note ⚠️: Please make sure this bot's command does not have the same name as another command in your server.

⚠️ Note ⚠️: Bot commands do not work in DMs with the bot as of now.
"""
287
+
288
+
289
# Minimal Gradio page: a single Markdown block with the bot's title and the
# welcome/usage text.
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so the heading is never indented
    # enough to be rendered as a code block.
    gr.Markdown(f"""
# Discord bot of https://huggingface.co/spaces/DeepFloyd/IF
{welcome_message}
""")


# Configure the queue in a single call. Calling queue() twice meant the second
# call (max_size only) reconfigured the queue without concurrency_count=100.
# NOTE(review): concurrency_count is a Gradio 3.x parameter - confirm the pinned version.
demo.queue(concurrency_count=100, max_size=100)
demo.launch()