lisonallen commited on
Commit
65762c4
·
1 Parent(s): 8718399

Improve image generation and display

Browse files
Files changed (1) hide show
  1. app.py +63 -16
app.py CHANGED
@@ -4,6 +4,8 @@ import random
4
  import logging
5
  import sys
6
  import os
 
 
7
 
8
  # 设置日志记录
9
  logging.basicConfig(level=logging.INFO,
@@ -54,6 +56,12 @@ try:
54
  except Exception as e:
55
  logger.error(f"Failed to patch Gradio: {str(e)}")
56
 
 
 
 
 
 
 
57
  # 加载模型
58
  try:
59
  from diffusers import DiffusionPipeline
@@ -85,7 +93,11 @@ def generate_image(prompt):
85
  try:
86
  if pipe is None:
87
  logger.error("Model not loaded, cannot generate image")
88
- return None
 
 
 
 
89
 
90
  logger.info(f"Generating image for prompt: {prompt}")
91
  seed = random.randint(0, MAX_SEED)
@@ -99,26 +111,60 @@ def generate_image(prompt):
99
  num_inference_steps=2,
100
  generator=generator
101
  ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  except IndexError as e:
103
  logger.error(f"Index error in pipe: {str(e)}")
104
- return None
105
 
106
- logger.info("Image generation successful")
107
- return image
108
  except Exception as e:
109
  logger.error(f"Error generating image: {str(e)}")
110
- return None
111
 
112
  # 创建简单的 Gradio 界面,禁用示例缓存
113
- demo = gr.Interface(
114
- fn=generate_image,
115
- inputs=gr.Textbox(label="Enter your prompt"),
116
- outputs=gr.Image(label="Generated Image"),
117
- title="SDXL Turbo Text-to-Image",
118
- description="Enter a text prompt to generate an image.",
119
- examples=["A cute cat", "Sunset over mountains"], # 减少示例数量
120
- cache_examples=False # 禁用示例缓存以避免文件访问错误
121
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # 启动应用
124
  if __name__ == "__main__":
@@ -126,9 +172,10 @@ if __name__ == "__main__":
126
  logger.info("Starting Gradio app")
127
  # 添加更多启动选项,以提高稳定性
128
  demo.launch(
 
129
  show_error=True,
130
- max_threads=1, # 降低并发以避免竞态条件
131
- share=False
132
  )
133
  except Exception as e:
134
  logger.error(f"Error launching app: {str(e)}")
 
4
  import logging
5
  import sys
6
  import os
7
+ from PIL import Image as PILImage
8
+ import io
9
 
10
  # 设置日志记录
11
  logging.basicConfig(level=logging.INFO,
 
56
  except Exception as e:
57
  logger.error(f"Failed to patch Gradio: {str(e)}")
58
 
59
+ # 创建一个简单的示例图像,在模型加载失败或生成失败时使用
60
+ def create_dummy_image():
61
+ # 创建一个256x256的红色图像
62
+ img = PILImage.new('RGB', (256, 256), color = (255, 0, 0))
63
+ return img
64
+
65
  # 加载模型
66
  try:
67
  from diffusers import DiffusionPipeline
 
93
  try:
94
  if pipe is None:
95
  logger.error("Model not loaded, cannot generate image")
96
+ return create_dummy_image()
97
+
98
+ if not prompt or prompt.strip() == "":
99
+ prompt = "A beautiful landscape"
100
+ logger.info(f"Empty prompt, using default: {prompt}")
101
 
102
  logger.info(f"Generating image for prompt: {prompt}")
103
  seed = random.randint(0, MAX_SEED)
 
111
  num_inference_steps=2,
112
  generator=generator
113
  ).images[0]
114
+
115
+ # 确保图像是 PIL.Image 类型
116
+ if not isinstance(image, PILImage.Image):
117
+ logger.warning(f"Converting image from {type(image)} to PIL.Image")
118
+ if hasattr(image, 'numpy'):
119
+ image = PILImage.fromarray(image.numpy())
120
+ else:
121
+ image = PILImage.fromarray(np.array(image))
122
+
123
+ # 转换为 RGB 模式,确保兼容性
124
+ if image.mode != 'RGB':
125
+ image = image.convert('RGB')
126
+
127
+ logger.info("Image generation successful")
128
+ # 保存图像以供调试
129
+ debug_path = "debug_image.jpg"
130
+ image.save(debug_path)
131
+ logger.info(f"Debug image saved to {debug_path}")
132
+
133
+ return image
134
  except IndexError as e:
135
  logger.error(f"Index error in pipe: {str(e)}")
136
+ return create_dummy_image()
137
 
 
 
138
  except Exception as e:
139
  logger.error(f"Error generating image: {str(e)}")
140
+ return create_dummy_image()
141
 
142
  # 创建简单的 Gradio 界面,禁用示例缓存
143
+ with gr.Blocks(title="SDXL Turbo Text-to-Image") as demo:
144
+ gr.Markdown("# SDXL Turbo Text-to-Image Generator")
145
+
146
+ with gr.Row():
147
+ with gr.Column():
148
+ prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your text prompt here...")
149
+ generate_button = gr.Button("Generate Image")
150
+
151
+ with gr.Column():
152
+ image_output = gr.Image(label="Generated Image", type="pil")
153
+
154
+ # 示例
155
+ gr.Examples(
156
+ examples=["A cute cat", "Sunset over mountains"],
157
+ inputs=prompt_input,
158
+ outputs=image_output,
159
+ fn=generate_image,
160
+ cache_examples=False
161
+ )
162
+
163
+ # 绑定生成按钮
164
+ generate_button.click(fn=generate_image, inputs=prompt_input, outputs=image_output)
165
+
166
+ # 直接绑定文本框的提交事件
167
+ prompt_input.submit(fn=generate_image, inputs=prompt_input, outputs=image_output)
168
 
169
  # 启动应用
170
  if __name__ == "__main__":
 
172
  logger.info("Starting Gradio app")
173
  # 添加更多启动选项,以提高稳定性
174
  demo.launch(
175
+ debug=True,
176
  show_error=True,
177
+ server_name="0.0.0.0", # 确保可以从外部访问
178
+ max_threads=1 # 降低并发以避免竞态条件
179
  )
180
  except Exception as e:
181
  logger.error(f"Error launching app: {str(e)}")