yeshog50 commited on
Commit
8074036
·
verified ·
1 Parent(s): 5e76037

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -25
app.py CHANGED
@@ -1,48 +1,52 @@
1
  import os
2
  import random
3
- import gradio as gr
4
  import torch
 
5
  from diffusers import StableDiffusionPipeline
 
 
6
 
7
- # Use the anime-specific Waifu Diffusion model
8
  BASE_MODEL = "hakurei/waifu-diffusion"
9
  MODEL_CACHE = "model_cache"
10
  os.makedirs(MODEL_CACHE, exist_ok=True)
11
 
12
  def get_pipeline():
13
- # Load the Waifu Diffusion model (anime style) on CPU
 
 
 
 
 
 
 
 
14
  pipe = StableDiffusionPipeline.from_pretrained(
15
  BASE_MODEL,
16
  torch_dtype=torch.float32,
17
  cache_dir=MODEL_CACHE,
18
- safety_checker=None,
 
19
  use_safetensors=True
20
  )
 
21
  # Move to CPU and enable memory optimizations
22
  pipe = pipe.to("cpu")
23
  pipe.enable_attention_slicing()
24
- pipe.enable_model_cpu_offload()
 
25
  return pipe
26
 
27
- # Load pipeline once at start
28
  pipeline = get_pipeline()
29
 
30
- def generate_image(
31
- prompt: str,
32
- negative_prompt: str = "",
33
- width: int = 768,
34
- height: int = 768,
35
- seed: int = -1,
36
- guidance_scale: float = 7.5,
37
- num_inference_steps: int = 25
38
- ):
39
- # Generate a random seed if none provided
40
  if seed == -1:
41
- seed = random.randint(0, 2**31-1)
42
  generator = torch.Generator(device="cpu").manual_seed(seed)
43
- # Generate the image with the pipeline
44
  with torch.no_grad():
45
- image = pipeline(
46
  prompt=f"anime style, {prompt}",
47
  negative_prompt=negative_prompt,
48
  width=width,
@@ -50,26 +54,30 @@ def generate_image(
50
  guidance_scale=guidance_scale,
51
  num_inference_steps=num_inference_steps,
52
  generator=generator
53
- ).images[0]
 
54
  return image, seed
55
 
56
- # Gradio Interface
57
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
58
- gr.Markdown("# πŸŒ€ Anime Image Generator (CPU only)")
59
  with gr.Row():
60
  with gr.Column():
61
- prompt = gr.Textbox(label="Prompt", lines=3, placeholder="e.g., 1girl, pink hair, night sky...")
62
- negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, lowres, disfigured")
63
  generate_btn = gr.Button("Generate", variant="primary")
 
64
  with gr.Accordion("Advanced", open=False):
65
  width = gr.Slider(512, 1024, value=768, step=64, label="Width")
66
  height = gr.Slider(512, 1024, value=768, step=64, label="Height")
67
  guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance")
68
  steps = gr.Slider(15, 50, value=25, step=1, label="Steps")
69
  seed = gr.Number(label="Seed", value=-1)
 
70
  with gr.Column():
71
  output_image = gr.Image(label="Result", type="pil")
72
  used_seed = gr.Textbox(label="Used Seed")
 
73
  generate_btn.click(
74
  generate_image,
75
  inputs=[prompt, negative_prompt, width, height, seed, guidance, steps],
@@ -77,4 +85,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
77
  )
78
 
79
  if __name__ == "__main__":
80
- demo.launch(share=True)
 
1
  import os
2
  import random
 
3
  import torch
4
+ import gradio as gr
5
  from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
7
+ from transformers import CLIPFeatureExtractor
8
 
9
+ # Configuration
10
  BASE_MODEL = "hakurei/waifu-diffusion"
11
  MODEL_CACHE = "model_cache"
12
  os.makedirs(MODEL_CACHE, exist_ok=True)
13
 
14
def get_pipeline():
    """Build the Waifu Diffusion text-to-image pipeline for CPU inference.

    Loads the NSFW safety checker and its feature extractor so generated
    images are filtered before being returned to the UI, then loads the
    base model in float32 (CPUs have no usable fp16 kernels) and enables
    attention slicing to reduce peak memory.

    Returns:
        StableDiffusionPipeline: a CPU-resident pipeline, ready to call.
    """
    # Safety components: checker model plus the CLIP preprocessor it expects.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    pipe = StableDiffusionPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,  # float32 for CPU-only execution
        cache_dir=MODEL_CACHE,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        use_safetensors=True,
    )

    # Keep the whole model on the CPU; attention slicing trades a little
    # speed for a much smaller activation-memory footprint.
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()
    # NOTE(review): enable_model_cpu_offload() was removed here. It offloads
    # submodules *from an accelerator* and raises on hosts with no CUDA/MPS
    # device, so calling it right after .to("cpu") breaks CPU-only Spaces.
    return pipe

# Load the pipeline once at import time so every request reuses it.
pipeline = get_pipeline()
42
 
43
def generate_image(prompt, negative_prompt="", width=768, height=768, seed=-1, guidance_scale=7.5, num_inference_steps=25):
    """Generate one anime-style image from *prompt*.

    Args:
        prompt: User prompt; "anime style, " is prepended before generation.
        negative_prompt: Traits to steer away from.
        width / height: Output size in pixels.
        seed: RNG seed; -1 means "pick a random seed".
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.

    Returns:
        tuple: (PIL image, the integer seed actually used) — the seed is
        returned so the UI can display it for reproducibility.
    """
    # Gradio Number/Slider widgets deliver floats; torch.Generator.manual_seed
    # and the pipeline's pixel/step arguments require real ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    if seed == -1:
        seed = random.randint(0, 2**31 - 1)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    with torch.no_grad():
        output = pipeline(
            prompt=f"anime style, {prompt}",
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        image = output.images[0]
    return image, seed
60
 
61
+ # Gradio UI
62
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
63
+ gr.Markdown("# πŸŒ€ Anime Image Generator (Safe for Public)")
64
  with gr.Row():
65
  with gr.Column():
66
+ prompt = gr.Textbox(label="Prompt", lines=3)
67
+ negative_prompt = gr.Textbox(label="Negative Prompt", value="blurry, lowres, bad anatomy")
68
  generate_btn = gr.Button("Generate", variant="primary")
69
+
70
  with gr.Accordion("Advanced", open=False):
71
  width = gr.Slider(512, 1024, value=768, step=64, label="Width")
72
  height = gr.Slider(512, 1024, value=768, step=64, label="Height")
73
  guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance")
74
  steps = gr.Slider(15, 50, value=25, step=1, label="Steps")
75
  seed = gr.Number(label="Seed", value=-1)
76
+
77
  with gr.Column():
78
  output_image = gr.Image(label="Result", type="pil")
79
  used_seed = gr.Textbox(label="Used Seed")
80
+
81
  generate_btn.click(
82
  generate_image,
83
  inputs=[prompt, negative_prompt, width, height, seed, guidance, steps],
 
85
  )
86
 
87
  if __name__ == "__main__":
88
+ demo.launch()