aiqtech committed · verified
Commit fe7b6d8 · 1 Parent(s): fedbad2

Update app.py

Files changed (1):
  1. app.py +67 -61
app.py CHANGED
@@ -16,69 +16,67 @@ from kolors.models.unet_2d_condition import UNet2DConditionModel
 from diffusers import EulerDiscreteScheduler
 from PIL import Image
 from insightface.app import FaceAnalysis
-from insightface.data import get_image as ins_get_image

-# Log in with the Hugging Face token
+# Login with HF token
 HF_TOKEN = os.getenv("HF_TOKEN")
 if HF_TOKEN:
     login(token=HF_TOKEN)
     print("Successfully logged in to Hugging Face Hub")

-# Download models (on CPU)
+# Download models
 print("Downloading models...")
 ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors", token=HF_TOKEN)
 ckpt_dir_faceid = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus", token=HF_TOKEN)

-# Initialize models on CPU
 print("Loading models on CPU first...")
+
+# Fix for ChatGLMTokenizer - monkey patch the _pad method
+original_chatglm_pad = ChatGLMTokenizer._pad if hasattr(ChatGLMTokenizer, '_pad') else None
+
+def fixed_pad(self, *args, **kwargs):
+    # Remove the unexpected 'padding_side' argument if present
+    kwargs.pop('padding_side', None)
+    if original_chatglm_pad:
+        return original_chatglm_pad(self, *args, **kwargs)
+    else:
+        return super(ChatGLMTokenizer, self)._pad(*args, **kwargs)
+
+ChatGLMTokenizer._pad = fixed_pad
+
+# Load models
 text_encoder = ChatGLMModel.from_pretrained(
     f'{ckpt_dir}/text_encoder',
     torch_dtype=torch.float16,
-    token=HF_TOKEN,
-    trust_remote_code=True,
-    device_map=None  # load on CPU first
+    trust_remote_code=True
 )

 tokenizer = ChatGLMTokenizer.from_pretrained(
     f'{ckpt_dir}/text_encoder',
-    token=HF_TOKEN,
     trust_remote_code=True
 )

 vae = AutoencoderKL.from_pretrained(
     f"{ckpt_dir}/vae",
-    torch_dtype=torch.float16,
-    token=HF_TOKEN
+    torch_dtype=torch.float16
 )

 scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")

 unet = UNet2DConditionModel.from_pretrained(
     f"{ckpt_dir}/unet",
-    torch_dtype=torch.float16,
-    token=HF_TOKEN
+    torch_dtype=torch.float16
 )

-# Load the CLIP model
-try:
-    clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-        'openai/clip-vit-large-patch14-336',
-        torch_dtype=torch.float16,
-        ignore_mismatched_sizes=True,
-        token=HF_TOKEN,
-        use_safetensors=True
-    )
-except:
-    clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-        'openai/clip-vit-large-patch14-336',
-        torch_dtype=torch.float16,
-        ignore_mismatched_sizes=True,
-        token=HF_TOKEN
-    )
+# Load CLIP
+clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    'openai/clip-vit-large-patch14-336',
+    torch_dtype=torch.float16,
+    use_safetensors=True
+)

 clip_image_processor = CLIPImageProcessor(size=336, crop_size=336)

-# Create the pipeline (on CPU)
+# Create pipeline
 pipe = StableDiffusionXLPipeline(
     vae=vae,
     text_encoder=text_encoder,
@@ -90,22 +88,21 @@ pipe = StableDiffusionXLPipeline(
     force_zeros_for_empty_prompt=False,
 )

-print("Models loaded on CPU successfully!")
+print("Models loaded successfully!")

 class FaceInfoGenerator():
     def __init__(self, root_dir="./.insightface/"):
-        # Use CPU only
         self.app = FaceAnalysis(
             name='antelopev2',
             root=root_dir,
-            providers=['CPUExecutionProvider']  # CPU only
+            providers=['CPUExecutionProvider']
         )
         self.app.prepare(ctx_id=0, det_size=(640, 640))

     def get_faceinfo_one_img(self, face_image):
         if face_image is None:
             return None
-
+
         face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))

         if len(face_info) == 0:
@@ -131,8 +128,7 @@ def face_bbox_to_square(bbox):
 MAX_SEED = np.iinfo(np.int32).max
 face_info_generator = FaceInfoGenerator()

-# GPU functions use the GPU only inside the @spaces.GPU decorator
-@spaces.GPU(duration=120)  # increased GPU duration
+@spaces.GPU(duration=120)
 def infer(prompt,
           image=None,
           negative_prompt="low quality, blurry, distorted",
@@ -145,10 +141,10 @@ def infer(prompt,
         gr.Warning("Please upload an image with a face.")
         return None, 0

-    # Face detection (on CPU)
+    # Face detection on CPU
     face_info = face_info_generator.get_faceinfo_one_img(image)
     if face_info is None:
-        raise gr.Error("No face detected in the image. Please provide an image with a clear face.")
+        raise gr.Error("No face detected. Please upload an image with a clear face.")

     face_bbox_square = face_bbox_to_square(face_info["bbox"])
     crop_image = image.crop(face_bbox_square)
@@ -156,15 +152,19 @@ def infer(prompt,
     crop_image = [crop_image]
     face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))

-    # Move to GPU (only inside spaces.GPU)
-    device = "cuda"
+    # Move to GPU
+    device = torch.device("cuda")
     global pipe

-    # Move models to GPU
-    pipe = pipe.to(device)
+    # Move models to GPU
+    pipe.vae = pipe.vae.to(device)
+    pipe.text_encoder = pipe.text_encoder.to(device)
+    pipe.unet = pipe.unet.to(device)
+    pipe.face_clip_encoder = pipe.face_clip_encoder.to(device)
+
     face_embeds = face_embeds.to(device, dtype=torch.float16)

-    # Load the IP Adapter
+    # Load IP adapter
     pipe.load_ip_adapter_faceid_plus(f'{ckpt_dir_faceid}/ipa-faceid-plus.bin', device=device)
     pipe.set_face_fidelity_scale(0.8)

@@ -173,9 +173,9 @@ def infer(prompt,

     generator = torch.Generator(device=device).manual_seed(seed)

-    # Generate image
+    # Generate image
     with torch.no_grad():
-        with torch.autocast(device):
+        with torch.autocast(device_type="cuda", dtype=torch.float16):
             result = pipe(
                 prompt=prompt,
                 negative_prompt=negative_prompt,
@@ -189,34 +189,41 @@ def infer(prompt,
                 face_insightface_embeds=face_embeds
             ).images[0]

+    # Move models back to CPU to free GPU memory
+    pipe.vae = pipe.vae.to("cpu")
+    pipe.text_encoder = pipe.text_encoder.to("cpu")
+    pipe.unet = pipe.unet.to("cpu")
+    pipe.face_clip_encoder = pipe.face_clip_encoder.to("cpu")
+    torch.cuda.empty_cache()
+
     return result, seed

 css = """
 footer {
     visibility: hidden;
 }
-#col-left {
-    margin: 0 auto;
+#col-left, #col-right {
     max-width: 640px;
-}
-#col-right {
     margin: 0 auto;
-    max-width: 640px;
+}
+.gr-button {
+    max-width: 100%;
 }
 """

+# Gradio interface
 with gr.Blocks(theme="soft", css=css) as Kolors:
     gr.HTML(
         """
         <div style='text-align: center;'>
             <h1>🎨 Kolors Face ID - AI Portrait Generator</h1>
-            <p>Upload a face photo and create stunning AI portraits with text prompts!</p>
+            <p>Upload a face photo and create stunning AI portraits!</p>
             <div style='display:flex; justify-content:center; gap:12px; margin-top:20px;'>
                 <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
-                    <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
+                    <img src="https://img.shields.io/badge/OpenFree-BEST%20AI-blue?style=for-the-badge" alt="OpenFree">
                 </a>
                 <a href="https://discord.gg/openfreeai" target="_blank">
-                    <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
+                    <img src="https://img.shields.io/badge/Discord-OpenFree%20AI-purple?style=for-the-badge&logo=discord" alt="Discord">
                 </a>
             </div>
         </div>
@@ -227,27 +234,26 @@ with gr.Blocks(theme="soft", css=css) as Kolors:
         with gr.Column(elem_id="col-left"):
             prompt = gr.Textbox(
                 label="Prompt",
-                placeholder="e.g., A professional portrait in business attire, studio lighting",
+                placeholder="Describe the portrait style you want...",
                 lines=3,
-                value="A professional portrait photo, high quality, detailed face"
+                value="A professional portrait photo, high quality"
             )
-            image = gr.Image(label="Upload Face Image", type="pil", height=400)
+            image = gr.Image(label="Upload Face Image", type="pil", height=300)

             with gr.Accordion("Advanced Settings", open=False):
                 negative_prompt = gr.Textbox(
                     label="Negative prompt",
-                    value="low quality, blurry, distorted, disfigured"
+                    value="low quality, blurry, distorted"
                 )
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=66)
                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                with gr.Row():
-                    guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
-                    num_inference_steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, step=1, value=25)
+                guidance_scale = gr.Slider(label="Guidance", minimum=1, maximum=10, step=0.5, value=5)
+                num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, step=5, value=25)

-            button = gr.Button("🎨 Generate Portrait", variant="primary", scale=1)
+            button = gr.Button("🎨 Generate Portrait", variant="primary")

         with gr.Column(elem_id="col-right"):
-            result = gr.Image(label="Generated Portrait", show_label=True)
+            result = gr.Image(label="Generated Portrait")
             seed_used = gr.Number(label="Seed Used", precision=0)

     button.click(
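
A note on the _pad monkey patch introduced above: newer transformers releases pass a padding_side keyword down to the tokenizer's internal _pad method, which the ChatGLM remote-code tokenizer predates, so calls crash with an unexpected-keyword TypeError. The commit wraps the original method and drops that keyword before delegating. Below is a minimal, self-contained sketch of the same pattern; LegacyTokenizer is a hypothetical stand-in so the snippet runs without the ChatGLM remote code.

# Stand-in for ChatGLMTokenizer: its _pad predates the padding_side keyword.
class LegacyTokenizer:
    def _pad(self, encoded, max_length=None):
        pad_count = 0 if max_length is None else max_length - len(encoded)
        return encoded + ["<pad>"] * pad_count

original_pad = LegacyTokenizer._pad

def fixed_pad(self, *args, **kwargs):
    # Drop the keyword the legacy method does not understand, then delegate.
    kwargs.pop('padding_side', None)
    return original_pad(self, *args, **kwargs)

LegacyTokenizer._pad = fixed_pad

tok = LegacyTokenizer()
# A caller passing padding_side (as newer transformers does) no longer crashes:
print(tok._pad(["a", "b"], max_length=4, padding_side="left"))  # ['a', 'b', '<pad>', '<pad>']

One caveat with this approach: dropping padding_side means any left-padding request is silently ignored rather than honored, which is tolerable here only because the pipeline controls its own padding.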
 
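The other structural change is the ZeroGPU lifecycle: the pipeline is built on CPU at import time, its components are moved to CUDA only inside the @spaces.GPU-decorated infer(), and they are moved back (with torch.cuda.empty_cache()) before returning, so the borrowed GPU is released between calls. A minimal sketch of that lifecycle, assuming the Hugging Face spaces package (available on Spaces) and a toy module in place of the pipeline; the try/finally is an extra safeguard not present in the commit:

import spaces  # Hugging Face ZeroGPU helper
import torch
import torch.nn as nn

model = nn.Linear(4, 4).half()  # built on CPU at import time, like the pipeline

@spaces.GPU(duration=120)  # a GPU is attached only while this function runs
def run(x: torch.Tensor) -> torch.Tensor:
    device = torch.device("cuda")
    try:
        model.to(device)  # move weights onto the borrowed GPU
        return model(x.to(device, dtype=torch.float16)).cpu()
    finally:
        model.to("cpu")           # hand the memory back between requests
        torch.cuda.empty_cache()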
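Finally, the autocast fix is easy to miss: the old torch.autocast(device) only worked while device happened to be the string "cuda", and it breaks once device becomes a torch.device, since autocast requires device_type as a string. The new torch.autocast(device_type="cuda", dtype=torch.float16) also pins the compute dtype explicitly. A quick self-contained check, runnable on any CUDA machine:

import torch

if torch.cuda.is_available():
    x = torch.randn(2, 2, device="cuda")
    w = torch.randn(2, 2, device="cuda")
    # Passing a torch.device here would raise; device_type must be a string.
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        y = x @ w
    print(y.dtype)  # torch.float16 inside the autocast region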