lunarflu HF Staff commited on
Commit
7253abd
·
1 Parent(s): 4fb358f

Synced repo using 'sync_with_huggingface' Github Action

Browse files
Files changed (1) hide show
  1. deepfloydif/deepfloydif.py +299 -0
deepfloydif/deepfloydif.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import glob
3
+ import os
4
+ import pathlib
5
+ import random
6
+ import threading
7
+
8
+ import gradio as gr
9
+ import discord
10
+ from gradio_client import Client
11
+ from PIL import Image
12
+ from discord.ext import commands
13
+
14
+ from discord.ui import Button, View
15
+
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+ deepfloydif_client = Client("huggingface-projects/IF", HF_TOKEN)
18
+ DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
19
+ intents = discord.Intents.all()
20
+ bot = commands.Bot(command_prefix="/", intents=intents)
21
+
22
+
23
+ @bot.event
24
+ async def on_ready():
25
+ print(f"Logged in as {bot.user} (ID: {bot.user.id})")
26
+ synced = await bot.tree.sync()
27
+ print(f"Synced commands: {', '.join([s.name for s in synced])}.")
28
+ print("------")
29
+
30
+
31
+ @bot.hybrid_command(
32
+ name="deepfloydif",
33
+ description="Enter a prompt to generate an image! Can generate realistic text, too!",
34
+ )
35
+ async def deepfloydif(ctx, prompt: str):
36
+ """DeepfloydIF stage 1 generation"""
37
+ try:
38
+ await deepfloydif_generate64(ctx, prompt)
39
+ except Exception as e:
40
+ print(f"Error: {e}")
41
+
42
+
43
+ def deepfloydif_generate64_inference(prompt):
44
+ """Generates four images based on a prompt"""
45
+ negative_prompt = ""
46
+ seed = random.randint(0, 1000)
47
+ number_of_images = 4
48
+ guidance_scale = 7
49
+ custom_timesteps_1 = "smart50"
50
+ number_of_inference_steps = 50
51
+ (
52
+ stage_1_images,
53
+ stage_1_param_path,
54
+ path_for_upscale256_upscaling,
55
+ ) = deepfloydif_client.predict(
56
+ prompt,
57
+ negative_prompt,
58
+ seed,
59
+ number_of_images,
60
+ guidance_scale,
61
+ custom_timesteps_1,
62
+ number_of_inference_steps,
63
+ api_name="/generate64",
64
+ )
65
+ return [stage_1_images, stage_1_param_path, path_for_upscale256_upscaling]
66
+
67
+
68
+ def deepfloydif_upscale256_inference(index, path_for_upscale256_upscaling):
69
+ """Upscales one of the images from deepfloydif_generate64_inference based on the chosen index"""
70
+ selected_index_for_upscale256 = index
71
+ seed_2 = 0
72
+ guidance_scale_2 = 4
73
+ custom_timesteps_2 = "smart50"
74
+ number_of_inference_steps_2 = 50
75
+ result_path = deepfloydif_client.predict(
76
+ path_for_upscale256_upscaling,
77
+ selected_index_for_upscale256,
78
+ seed_2,
79
+ guidance_scale_2,
80
+ custom_timesteps_2,
81
+ number_of_inference_steps_2,
82
+ api_name="/upscale256",
83
+ )
84
+ return result_path
85
+
86
+
87
+ def deepfloydif_upscale1024_inference(index, path_for_upscale256_upscaling, prompt):
88
+ """Upscales to stage 2, then stage 3"""
89
+ selected_index_for_upscale256 = index
90
+ seed_2 = 0 # default seed for stage 2 256 upscaling
91
+ guidance_scale_2 = 4 # default for stage 2
92
+ custom_timesteps_2 = "smart50" # default for stage 2
93
+ number_of_inference_steps_2 = 50 # default for stage 2
94
+ negative_prompt = "" # empty (not used, could add in the future)
95
+
96
+ seed_3 = 0 # default for stage 3 1024 upscaling
97
+ guidance_scale_3 = 9 # default for stage 3
98
+ number_of_inference_steps_3 = 40 # default for stage 3
99
+
100
+ result_path = deepfloydif_client.predict(
101
+ path_for_upscale256_upscaling,
102
+ selected_index_for_upscale256,
103
+ seed_2,
104
+ guidance_scale_2,
105
+ custom_timesteps_2,
106
+ number_of_inference_steps_2,
107
+ prompt,
108
+ negative_prompt,
109
+ seed_3,
110
+ guidance_scale_3,
111
+ number_of_inference_steps_3,
112
+ api_name="/upscale1024",
113
+ )
114
+ return result_path
115
+
116
+
117
+ def load_image(png_files, stage_1_images):
118
+ """Opens images as variables so we can combine them later"""
119
+ results = []
120
+ for file in png_files:
121
+ png_path = os.path.join(stage_1_images, file)
122
+ results.append(Image.open(png_path))
123
+ return results
124
+
125
+
126
+ def combine_images(png_files, stage_1_images, partial_path):
127
+ if os.environ.get("TEST_ENV") == "True":
128
+ print("Combining images for deepfloydif_generate64")
129
+ images = load_image(png_files, stage_1_images)
130
+ combined_image = Image.new("RGB", (images[0].width * 2, images[0].height * 2))
131
+ combined_image.paste(images[0], (0, 0))
132
+ combined_image.paste(images[1], (images[0].width, 0))
133
+ combined_image.paste(images[2], (0, images[0].height))
134
+ combined_image.paste(images[3], (images[0].width, images[0].height))
135
+ combined_image_path = os.path.join(stage_1_images, f"{partial_path}.png")
136
+ combined_image.save(combined_image_path)
137
+ return combined_image_path
138
+
139
+
140
+ async def deepfloydif_generate64(ctx, prompt):
141
+ """DeepfloydIF command (generate images with realistic text using slash commands)"""
142
+ try:
143
+ channel = ctx.channel
144
+ # interaction.response message can't be used to create a thread, so we create another message
145
+ message = await ctx.send(f"**{prompt}** - {ctx.author.mention} (generating...)")
146
+
147
+ loop = asyncio.get_running_loop()
148
+ result = await loop.run_in_executor(None, deepfloydif_generate64_inference, prompt)
149
+ stage_1_images = result[0]
150
+ path_for_upscale256_upscaling = result[2]
151
+
152
+ partial_path = pathlib.Path(path_for_upscale256_upscaling).name
153
+ png_files = list(glob.glob(f"{stage_1_images}/**/*.png"))
154
+
155
+ if png_files:
156
+ await message.delete()
157
+ combined_image_path = combine_images(png_files, stage_1_images, partial_path)
158
+ if os.environ.get("TEST_ENV") == "True":
159
+ print("Images combined for deepfloydif_generate64")
160
+
161
+ with Image.open(combined_image_path) as img:
162
+ width, height = img.size
163
+ new_width = width * 3
164
+ new_height = height * 3
165
+ resized_img = img.resize((new_width, new_height))
166
+ x2_combined_image_path = combined_image_path
167
+ resized_img.save(x2_combined_image_path)
168
+
169
+ # making image bigger, more readable
170
+ with open(x2_combined_image_path, "rb") as f: # was combined_image_path
171
+ button1 = Button(custom_id="0", emoji="↖")
172
+ button2 = Button(custom_id="1", emoji="↗")
173
+ button3 = Button(custom_id="2", emoji="↙")
174
+ button4 = Button(custom_id="3", emoji="↘")
175
+
176
+ async def button_callback(interaction):
177
+ index = int(interaction.data["custom_id"]) # 0,1,2,3
178
+
179
+ await interaction.response.send_message(
180
+ f"{interaction.user.mention} (upscaling...)", ephemeral=True
181
+ )
182
+ result_path = await deepfloydif_upscale256(index, path_for_upscale256_upscaling)
183
+
184
+ # create and use upscale 1024 button
185
+ with open(result_path, "rb") as f:
186
+ upscale1024 = Button(label="High-quality upscale (x4)", custom_id=str(index))
187
+ upscale1024.callback = upscale1024_callback
188
+ view = View(timeout=None)
189
+ view.add_item(upscale1024)
190
+
191
+ await interaction.delete_original_response()
192
+ await channel.send(
193
+ content=(
194
+ f"{interaction.user.mention} Here is the upscaled image! Click the button"
195
+ " to upscale even more!"
196
+ ),
197
+ file=discord.File(f, f"{prompt}.png"),
198
+ view=view,
199
+ )
200
+
201
+ async def upscale1024_callback(interaction):
202
+ index = int(interaction.data["custom_id"])
203
+
204
+ await interaction.response.send_message(
205
+ f"{interaction.user.mention} (upscaling...)", ephemeral=True
206
+ )
207
+ result_path = await deepfloydif_upscale1024(index, path_for_upscale256_upscaling, prompt)
208
+
209
+ with open(result_path, "rb") as f:
210
+ await interaction.delete_original_response()
211
+ await channel.send(
212
+ content=f"{interaction.user.mention} Here's your high-quality x16 image!",
213
+ file=discord.File(f, f"{prompt}.png"),
214
+ )
215
+
216
+ button1.callback = button_callback
217
+ button2.callback = button_callback
218
+ button3.callback = button_callback
219
+ button4.callback = button_callback
220
+
221
+ view = View(timeout=None)
222
+ view.add_item(button1)
223
+ view.add_item(button2)
224
+ view.add_item(button3)
225
+ view.add_item(button4)
226
+
227
+ # could store this message as combined_image_dfif in case it's useful for future testing
228
+ await channel.send(
229
+ f"**{prompt}** - {ctx.author.mention} Click a button to upscale! (make larger + enhance quality)",
230
+ file=discord.File(f, f"{partial_path}.png"),
231
+ view=view,
232
+ )
233
+ else:
234
+ await ctx.send(f"{ctx.author.mention} No PNG files were found, cannot post them!")
235
+
236
+ except Exception as e:
237
+ print(f"Error: {e}")
238
+
239
+
240
+ async def deepfloydif_upscale256(index: int, path_for_upscale256_upscaling):
241
+ """upscaling function for images generated using /deepfloydif"""
242
+ try:
243
+ loop = asyncio.get_running_loop()
244
+ result_path = await loop.run_in_executor(
245
+ None, deepfloydif_upscale256_inference, index, path_for_upscale256_upscaling
246
+ )
247
+ return result_path
248
+
249
+ except Exception as e:
250
+ print(f"Error: {e}")
251
+
252
+
253
+ async def deepfloydif_upscale1024(index: int, path_for_upscale256_upscaling, prompt):
254
+ """upscaling function for images generated using /deepfloydif"""
255
+ try:
256
+ loop = asyncio.get_running_loop()
257
+ result_path = await loop.run_in_executor(
258
+ None, deepfloydif_upscale1024_inference, index, path_for_upscale256_upscaling, prompt
259
+ )
260
+ return result_path
261
+
262
+ except Exception as e:
263
+ print(f"Error: {e}")
264
+
265
+
266
+ def run_bot():
267
+ bot.run(DISCORD_TOKEN)
268
+
269
+
270
+ threading.Thread(target=run_bot).start()
271
+
272
+
273
+ welcome_message = """
274
+ ## Add this bot to your server by clicking this link:
275
+
276
+ https://discord.com/api/oauth2/authorize?client_id=1154395078735953930&permissions=51200&scope=bot
277
+
278
+ ## How to use it?
279
+
280
+ The bot can be triggered via `/deepfloydif` followed by your text prompt.
281
+
282
+ This will generate images based on the text prompt. You can upscale the images using the buttons up to 16x!
283
+
284
+ ⚠️ Note ⚠️: Please make sure this bot's command does have the same name as another command in your server.
285
+
286
+ ⚠️ Note ⚠️: Bot commands do not work in DMs with the bot as of now.
287
+ """
288
+
289
+
290
+ with gr.Blocks() as demo:
291
+ gr.Markdown(f"""
292
+ # Discord bot of https://huggingface.co/spaces/DeepFloyd/IF
293
+ {welcome_message}
294
+ """)
295
+
296
+
297
+ demo.queue(concurrency_count=100)
298
+ demo.queue(max_size=100)
299
+ demo.launch()