George Krupenchenkov committed on
Commit
f6e532b
·
1 Parent(s): 84c6330

hw changes 2

Browse files
Files changed (1) hide show
  1. app.py +35 -20
app.py CHANGED
@@ -9,6 +9,10 @@ from diffusers import DiffusionPipeline
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  # model_repo_id = "stabilityai/sdxl-turbo" # Replace with the model you would like to use
 
 
 
 
12
 
13
  if torch.cuda.is_available():
14
  torch_dtype = torch.float16
@@ -24,22 +28,22 @@ MAX_IMAGE_SIZE = 1024
24
 
25
  # @spaces.GPU #[uncomment to use ZeroGPU]
26
  def infer(
27
- model_repo_id,
28
  prompt,
29
  negative_prompt,
30
- seed,
31
  randomize_seed,
32
  width,
33
  height,
34
- guidance_scale,
35
- num_inference_steps,
 
 
36
  progress=gr.Progress(track_tqdm=True),
37
-
38
  ):
39
  if randomize_seed:
40
  seed = random.randint(0, MAX_SEED)
41
 
42
  generator = torch.Generator().manual_seed(seed)
 
43
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
44
  pipe = pipe.to(device)
45
 
@@ -56,6 +60,7 @@ def infer(
56
  return image, seed
57
 
58
 
 
59
  examples = [
60
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
61
  "An astronaut riding a green horse",
@@ -71,16 +76,9 @@ css = """
71
 
72
  with gr.Blocks(css=css) as demo:
73
  with gr.Column(elem_id="col-container"):
74
- gr.Markdown(" # Text-to-Image Gradio Template")
75
 
76
  with gr.Row():
77
- model_repo_id = gr.Text(
78
- label="model_repo_id",
79
- show_label=False,
80
- max_lines=1,
81
- placeholder="Enter your model_repo_id",
82
- container=False,
83
- )
84
  prompt = gr.Text(
85
  label="Prompt",
86
  show_label=False,
@@ -94,11 +92,27 @@ with gr.Blocks(css=css) as demo:
94
  result = gr.Image(label="Result", show_label=False)
95
 
96
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  negative_prompt = gr.Text(
98
  label="Negative prompt",
99
  max_lines=1,
100
  placeholder="Enter a negative prompt",
101
- visible=False,
102
  )
103
 
104
  seed = gr.Slider(
@@ -106,10 +120,10 @@ with gr.Blocks(css=css) as demo:
106
  minimum=0,
107
  maximum=MAX_SEED,
108
  step=1,
109
- value=0,
110
  )
111
 
112
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
113
 
114
  with gr.Row():
115
  width = gr.Slider(
@@ -134,7 +148,7 @@ with gr.Blocks(css=css) as demo:
134
  minimum=0.0,
135
  maximum=10.0,
136
  step=0.1,
137
- value=0.0, # Replace with defaults that work for your model
138
  )
139
 
140
  num_inference_steps = gr.Slider(
@@ -142,7 +156,7 @@ with gr.Blocks(css=css) as demo:
142
  minimum=1,
143
  maximum=50,
144
  step=1,
145
- value=2, # Replace with defaults that work for your model
146
  )
147
 
148
  gr.Examples(examples=examples, inputs=[prompt])
@@ -150,13 +164,13 @@ with gr.Blocks(css=css) as demo:
150
  triggers=[run_button.click, prompt.submit],
151
  fn=infer,
152
  inputs=[
153
- model_repo_id,
154
  prompt,
155
  negative_prompt,
156
- seed,
157
  randomize_seed,
158
  width,
159
  height,
 
 
160
  guidance_scale,
161
  num_inference_steps,
162
  ],
@@ -165,3 +179,4 @@ with gr.Blocks(css=css) as demo:
165
 
166
  if __name__ == "__main__":
167
  demo.launch()
 
 
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  # model_repo_id = "stabilityai/sdxl-turbo" # Replace with the model you would like to use
12
+ model_repo_id = "CompVis/stable-diffusion-v1-4"
13
+ model_dropdown = ['stabilityai/sdxl-turbo', 'CompVis/stable-diffusion-v1-4' ]
14
+
15
+
16
 
17
  if torch.cuda.is_available():
18
  torch_dtype = torch.float16
 
28
 
29
  # @spaces.GPU #[uncomment to use ZeroGPU]
30
  def infer(
 
31
  prompt,
32
  negative_prompt,
 
33
  randomize_seed,
34
  width,
35
  height,
36
+ model_repo_id=model_repo_id,
37
+ seed=42,
38
+ guidance_scale=7,
39
+ num_inference_steps=20,
40
  progress=gr.Progress(track_tqdm=True),
 
41
  ):
42
  if randomize_seed:
43
  seed = random.randint(0, MAX_SEED)
44
 
45
  generator = torch.Generator().manual_seed(seed)
46
+
47
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
48
  pipe = pipe.to(device)
49
 
 
60
  return image, seed
61
 
62
 
63
+
64
  examples = [
65
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
66
  "An astronaut riding a green horse",
 
76
 
77
  with gr.Blocks(css=css) as demo:
78
  with gr.Column(elem_id="col-container"):
79
+ gr.Markdown(" # Text-to-Image SemaSci Template")
80
 
81
  with gr.Row():
 
 
 
 
 
 
 
82
  prompt = gr.Text(
83
  label="Prompt",
84
  show_label=False,
 
92
  result = gr.Image(label="Result", show_label=False)
93
 
94
  with gr.Accordion("Advanced Settings", open=False):
95
+ # model_repo_id = gr.Text(
96
+ # label="Model Id",
97
+ # max_lines=1,
98
+ # placeholder="Choose model",
99
+ # visible=True,
100
+ # value=model_repo_id,
101
+ # )
102
+ model_repo_id = gr.Dropdown(
103
+ label="Model Id",
104
+ choices=model_dropdown,
105
+ info="Choose model",
106
+ visible=True,
107
+ allow_custom_value=True,
108
+ value=model_repo_id,
109
+ )
110
+
111
  negative_prompt = gr.Text(
112
  label="Negative prompt",
113
  max_lines=1,
114
  placeholder="Enter a negative prompt",
115
+ visible=True,
116
  )
117
 
118
  seed = gr.Slider(
 
120
  minimum=0,
121
  maximum=MAX_SEED,
122
  step=1,
123
+ value=42,
124
  )
125
 
126
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
127
 
128
  with gr.Row():
129
  width = gr.Slider(
 
148
  minimum=0.0,
149
  maximum=10.0,
150
  step=0.1,
151
+ value=7.0, # Replace with defaults that work for your model
152
  )
153
 
154
  num_inference_steps = gr.Slider(
 
156
  minimum=1,
157
  maximum=50,
158
  step=1,
159
+ value=20, # Replace with defaults that work for your model
160
  )
161
 
162
  gr.Examples(examples=examples, inputs=[prompt])
 
164
  triggers=[run_button.click, prompt.submit],
165
  fn=infer,
166
  inputs=[
 
167
  prompt,
168
  negative_prompt,
 
169
  randomize_seed,
170
  width,
171
  height,
172
+ model_repo_id,
173
+ seed,
174
  guidance_scale,
175
  num_inference_steps,
176
  ],
 
179
 
180
  if __name__ == "__main__":
181
  demo.launch()
182
+