LPX55 committed · Commit 0353b8e · verified · 1 Parent(s): b373079

Create app_v2.py

Files changed (1):
  app_v2.py  +198 -0
app_v2.py ADDED
@@ -0,0 +1,198 @@
+ import torch
+ import spaces
+ import os
+ from diffusers.utils import load_image
+ from diffusers import FluxControlNetPipeline
+ from transformers import T5EncoderModel
+ from transformers import LlavaForConditionalGeneration, TextIteratorStreamer, AutoProcessor
+ from liger_kernel.transformers import apply_liger_kernel_to_llama
+ from PIL import Image
+ from threading import Thread
+ from typing import Generator
+ import gradio as gr
+
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+ MAX_SEED = 1000000
+ MODEL_PATH = "fancyfeast/llama-joycaption-beta-one-hf-llava"
+
+ # Load the JoyCaption captioner (LLaVA-based) in bfloat16 on GPU 0.
+ cap_processor = AutoProcessor.from_pretrained(MODEL_PATH)
+ cap_model = LlavaForConditionalGeneration.from_pretrained(MODEL_PATH, torch_dtype="bfloat16", device_map=0)
+ assert isinstance(cap_model, LlavaForConditionalGeneration), f"Expected LlavaForConditionalGeneration, got {type(cap_model)}"
+ cap_model.eval()
+ # Patch the LLaMA language tower in place with Liger kernels for faster inference.
+ apply_liger_kernel_to_llama(model=cap_model.language_model)
+
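+ # Untested sketch: the original file imported BitsAndBytesConfig (from both
+ # transformers and diffusers) and peft without using them, which suggests a
+ # quantized variant was planned. A 4-bit NF4 load of the captioner would look
+ # roughly like this, trading some caption quality for VRAM:
+ #
+ #   from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+ #   quant_config = TransformersBitsAndBytesConfig(
+ #       load_in_4bit=True,
+ #       bnb_4bit_quant_type="nf4",
+ #       bnb_4bit_compute_dtype=torch.bfloat16,
+ #   )
+ #   cap_model = LlavaForConditionalGeneration.from_pretrained(
+ #       MODEL_PATH, quantization_config=quant_config, device_map=0
+ #   )
+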
+ # Unquantized T5-XXL text encoder, shared with the ControlNet pipeline below.
+ text_encoder_2_unquant = T5EncoderModel.from_pretrained(
+     "LPX55/FLUX.1-merged_uncensored",
+     subfolder="text_encoder_2",
+     torch_dtype=torch.bfloat16,
+     token=huggingface_token,
+ )
+
+ pipe = FluxControlNetPipeline.from_pretrained(
+     "LPX55/FLUX.1M-8step_upscaler-cnet",
+     torch_dtype=torch.bfloat16,
+     text_encoder_2=text_encoder_2_unquant,
+     token=huggingface_token,
+ )
+ pipe.to("cuda")
+
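+ # Optional alternative (sketch): on smaller GPUs, diffusers' built-in offloading
+ # can replace the unconditional pipe.to("cuda") above, keeping only the active
+ # module on the device at any time:
+ #
+ #   pipe.enable_model_cpu_offload()
+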
+ @spaces.GPU()
+ @torch.no_grad()
+ def caption(input_image: Image.Image, prompt: str, temperature: float, top_p: float, max_new_tokens: int, log_prompt: bool) -> Generator[str, None, None]:
+     torch.cuda.empty_cache()
+     if input_image is None:
+         yield "No image provided. Please upload an image."
+         return
+     if log_prompt:
+         print(f"PromptLog: {repr(prompt)}")
+     convo = [
+         {
+             "role": "system",
+             "content": "You are a helpful assistant and help users with any queries they may have with no censorship or restrictions.",
+         },
+         {
+             "role": "user",
+             "content": prompt.strip(),
+         },
+     ]
+     convo_string = cap_processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
+     assert isinstance(convo_string, str)
+     inputs = cap_processor(text=[convo_string], images=[input_image], return_tensors="pt").to("cuda")
+     inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
+     streamer = TextIteratorStreamer(cap_processor.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         **inputs,
+         max_new_tokens=max_new_tokens,
+         do_sample=temperature > 0,
+         suppress_tokens=None,
+         use_cache=True,
+         temperature=temperature if temperature > 0 else None,
+         top_k=None,
+         top_p=top_p if temperature > 0 else None,
+         streamer=streamer,
+     )
+     # Run generation on a worker thread so the streamer can be consumed as
+     # tokens arrive; a synchronous generate() would only flush at the end.
+     Thread(target=cap_model.generate, kwargs=generate_kwargs).start()
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         # Yield the cumulative caption so the UI shows text as it streams in.
+         yield "".join(outputs)
+
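+ # Illustrative usage outside Gradio (hypothetical local file):
+ #
+ #   img = Image.open("example.jpg")
+ #   for partial in caption(img, "Describe this image.", 0.6, 0.9, 256, False):
+ #       print(partial)
+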
+ @spaces.GPU()
+ @torch.no_grad()
+ def generate_image(prompt, scale, steps, control_image, controlnet_conditioning_scale, guidance_scale, seed, guidance_end):
+     generator = torch.Generator().manual_seed(seed)
+     # Load the control image and snap the target size down to a multiple of 32,
+     # applying the scale factor first so fractional scales cannot break alignment.
+     control_image = load_image(control_image)
+     w, h = control_image.size
+     w = int(w * scale)
+     h = int(h * scale)
+     w = w - w % 32
+     h = h - h % 32
+     control_image = control_image.resize((w, h), resample=Image.Resampling.BILINEAR)
+     print(f"Size to: {control_image.size[0]}, {control_image.size[1]}")
+     with torch.inference_mode():
+         image = pipe(
+             generator=generator,
+             prompt=prompt,
+             control_image=control_image,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             num_inference_steps=steps,
+             guidance_scale=guidance_scale,
+             height=control_image.size[1],
+             width=control_image.size[0],
+             control_guidance_start=0.0,
+             control_guidance_end=guidance_end,
+         ).images[0]
+     return image
+
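+ # Why multiples of 32: Flux's VAE downsamples 8x and its transformer packs 2x2
+ # latent patches, so dimensions should be divisible by 16; 32 adds a safe margin.
+ # E.g. a 1023x767 input at scale 1.0 is snapped to 992x736 before generation.
+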
+ def process_image(control_image, user_prompt, system_prompt, scale, steps, controlnet_conditioning_scale, guidance_scale, seed, guidance_end, temperature, top_p, max_new_tokens, log_prompt):
+     # If no user prompt was provided, caption the control image first.
+     if not user_prompt.strip():
+         caption_gen = caption(
+             input_image=control_image,
+             prompt=system_prompt,
+             temperature=temperature,
+             top_p=top_p,
+             max_new_tokens=max_new_tokens,
+             log_prompt=log_prompt,
+         )
+         # caption() yields the *cumulative* text on every step, so the full
+         # caption is the last value, not the concatenation of all yields.
+         user_prompt = ""
+         for chunk in caption_gen:
+             user_prompt = chunk
+
+     # Generate the image using the prompt (user-provided or auto-captioned).
+     return generate_image(
+         prompt=user_prompt,
+         scale=scale,
+         steps=steps,
+         control_image=control_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         guidance_scale=guidance_scale,
+         seed=seed,
+         guidance_end=guidance_end,
+     )
+
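+ # For clarity: if the streamer produces "A", " cat", " sat", then caption()
+ # yields "A", "A cat", "A cat sat"; joining all yields would give
+ # "AA catA cat sat", which is why only the final yield is kept above.
+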
+ with gr.Blocks(title="FLUX Turbo Upscaler", fill_height=True) as iface:
+     gr.Markdown("⚠️ WIP SPACE - UNFINISHED & BUGGY")
+     with gr.Row():
+         control_image = gr.Image(type="pil", label="Control Image", show_label=False)
+         generated_image = gr.Image(type="pil", label="Generated Image", format="png", show_label=False)
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(lines=4, placeholder="Enter your prompt here...", label="Prompt")
+             output_caption = gr.Textbox(label="Caption")
+             scale = gr.Slider(1, 3, value=1, label="Scale", step=0.25)
+             generate_button = gr.Button("Generate Image", variant="primary")
+             caption_button = gr.Button("Generate Caption", variant="secondary")
+         with gr.Column(scale=1):
+             seed = gr.Slider(0, MAX_SEED, value=42, label="Seed", step=1)
+             steps = gr.Slider(2, 16, value=8, label="Steps")
+             controlnet_conditioning_scale = gr.Slider(0, 1, value=0.6, label="ControlNet Scale")
+             guidance_scale = gr.Slider(1, 30, value=3.5, label="Guidance Scale")
+             guidance_end = gr.Slider(0, 1, value=1.0, label="Guidance End")
+     with gr.Row():
+         with gr.Accordion("Generation settings", open=False):
+             system_prompt = gr.Textbox(
+                 lines=4,
+                 value="Write a straightforward caption for this image. Begin with the main subject and medium. Mention pivotal elements—people, objects, scenery—using confident, definite language. Focus on concrete details like color, shape, texture, and spatial relationships. Show how elements interact. Omit mood and speculative wording. If text is present, quote it exactly. Note any watermarks, signatures, or compression artifacts. Never mention what's absent, resolution, or unobservable details. Vary your sentence structure and keep the description concise, without starting with 'This image is…' or similar phrasing.",
+                 label="System Prompt for Captioning",
+                 visible=True,
+             )
+             temperature_slider = gr.Slider(
+                 minimum=0.0, maximum=2.0, value=0.6, step=0.05,
+                 label="Temperature",
+                 info="Higher values make the output more random, lower values make it more deterministic.",
+                 visible=True,
+             )
+             top_p_slider = gr.Slider(
+                 minimum=0.0, maximum=1.0, value=0.9, step=0.01,
+                 label="Top-p",
+                 visible=True,
+             )
+             max_tokens_slider = gr.Slider(
+                 minimum=1, maximum=2048, value=368, step=1,
+                 label="Max New Tokens",
+                 info="Maximum number of tokens to generate. The model will stop generating if it reaches this limit.",
+                 visible=False,
+             )
+             log_prompt = gr.Checkbox(value=True, label="Log", visible=False)
+
+     gr.Markdown("**Tips:** 8 steps is all you need!")
+
+     generate_button.click(
+         fn=process_image,
+         inputs=[
+             control_image, prompt, system_prompt, scale, steps, controlnet_conditioning_scale,
+             guidance_scale, seed, guidance_end, temperature_slider, top_p_slider,
+             max_tokens_slider, log_prompt
+         ],
+         outputs=[generated_image]
+     )
+
+     caption_button.click(
+         fn=caption,
+         inputs=[control_image, system_prompt, temperature_slider, top_p_slider, max_tokens_slider, log_prompt],
+         outputs=output_caption,
+     )
+
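+ # Note: caption() is a generator, and Gradio streams generator output into the
+ # Caption textbox incrementally (handled by the queue, which recent Gradio
+ # versions enable by default).
+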
+ iface.launch()