LinB203
init
0c8d55e
import sys
sys.path.append("..")
from transformers import AutoTokenizer, AutoProcessor
from univa.models.qwen2p5vl.modeling_univa_qwen2p5vl import UnivaQwen2p5VLForConditionalGeneration
from transformers import SiglipImageProcessor, SiglipVisionModel
from univa.utils.flux_pipeline import FluxPipeline
from univa.utils.get_ocr import get_ocr_result
from univa.utils.denoiser_prompt_embedding_flux import encode_prompt
from qwen_vl_utils import process_vision_info
from univa.utils.anyres_util import dynamic_resize
import torch
from PIL import Image
from transformers import set_seed
from torch import nn
import os
import argparse
# Make runs reproducible: deterministic cuDNN kernels and a fixed seed for
# every RNG (transformers' set_seed plus the current CUDA device).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
seed = 42
set_seed(seed)
torch.cuda.manual_seed(seed)

# Path template for generated images; ``{}`` receives a running counter.
generate_image_temp = './generate_image_{}.png'
def load_main_model_and_processor(
    model_path,
    device,
    min_pixels=448*448,
    max_pixels=448*448
):
    """Load the UniWorld LVLM, its task-routing head and its processor.

    Args:
        model_path: Checkpoint location of the UnivaQwen2p5VL model. The
            routing-head weights are read from
            ``<model_path>/task_head_final.pt``.
        device: torch device the model and the task head are moved to.
        min_pixels: Lower pixel budget forwarded to ``AutoProcessor``.
        max_pixels: Upper pixel budget forwarded to ``AutoProcessor``.

    Returns:
        ``(model, task_head, processor)``. ``task_head`` maps a 3584-wide
        vector to 2 scores and is left in eval mode.
    """
    model = UnivaQwen2p5VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).to(device)

    # Two-way classifier applied to an LVLM hidden vector (3584 features).
    task_head = nn.Sequential(
        nn.Linear(3584, 10240),
        nn.SiLU(),
        nn.Dropout(0.3),
        nn.Linear(10240, 2)
    ).to(device)
    # Read from the ``model_path`` argument rather than the global CLI
    # ``args``, so the function also works when imported. ``map_location``
    # lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True on recent PyTorch).
    state_dict = torch.load(
        os.path.join(model_path, 'task_head_final.pt'),
        map_location=device,
    )
    task_head.load_state_dict(state_dict)
    task_head.eval()

    processor = AutoProcessor.from_pretrained(
        model_path,
        min_pixels=min_pixels, max_pixels=max_pixels
    )
    return model, task_head, processor
def load_pipe(
    denoiser,
    flux_path,
    device,
):
    """Build a bfloat16 FLUX pipeline around an externally supplied denoiser.

    Returns:
        ``(pipe, tokenizers, text_encoders)``. The two lists hold the
        pipeline's first and second tokenizer and text encoder, in that order.
    """
    flux = FluxPipeline.from_pretrained(
        flux_path,
        transformer=denoiser,
        torch_dtype=torch.bfloat16,
    ).to(device)
    tokenizer_pair = [flux.tokenizer, flux.tokenizer_2]
    encoder_pair = [flux.text_encoder, flux.text_encoder_2]
    return flux, tokenizer_pair, encoder_pair
def load_siglip_and_processor(
    siglip_path,
    device,
):
    """Load a SigLIP image processor and vision tower, or nothing at all.

    When ``siglip_path`` is falsy, both returned values are ``None``.
    Otherwise the vision model is loaded in bfloat16 and moved to ``device``.

    Returns:
        ``(siglip_processor, siglip_model)``.
    """
    if not siglip_path:
        return None, None
    image_processor = SiglipImageProcessor.from_pretrained(siglip_path)
    vision_model = SiglipVisionModel.from_pretrained(
        siglip_path,
        torch_dtype=torch.bfloat16,
    ).to(device)
    return image_processor, vision_model
def preprocess_siglip_pixel_values(siglip_model, siglip_processor, image_paths):
    """Encode a list of image files with the SigLIP vision tower.

    Args:
        siglip_model: SigLIP vision model; its ``device`` receives the inputs.
        siglip_processor: Matching SigLIP image processor.
        image_paths: Non-empty sequence of image paths. ``torch.concat``
            raises on an empty list, so callers must guard against that.

    Returns:
        ``last_hidden_state`` of the vision model, batched in the order of
        ``image_paths``.
    """
    pixel_batches = []
    for image_path in image_paths:
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(image_path) as image:
            pixel_batches.append(
                siglip_processor.preprocess(
                    images=image.convert('RGB'),
                    do_resize=True, return_tensors="pt", do_convert_rgb=True
                ).pixel_values  # 1 c h w
            )
    siglip_pixel_values = torch.concat(pixel_batches)  # b c h w
    siglip_pixel_values = siglip_pixel_values.to(siglip_model.device)
    # Features are only consumed for inference; skip autograd bookkeeping.
    with torch.no_grad():
        siglip_hidden_states = siglip_model(siglip_pixel_values).last_hidden_state
    return siglip_hidden_states
def update_size(i1, i2, anyres='any_11ratio', anchor_pixels=1024*1024):
    """Choose an output ``(height, width)`` from up to two reference images.

    Args:
        i1: Optional image path; falsy values are ignored.
        i2: Optional image path; falsy values are ignored.
        anyres: Resize mode forwarded to ``dynamic_resize``.
        anchor_pixels: Pixel budget forwarded to ``dynamic_resize``.

    Returns:
        ``(new_h, new_w)``.
        - With no images: a square with side ``int(anchor_pixels ** 0.5)``.
        - With one image: that image's size is passed to ``dynamic_resize``.
        - With two images: their mean width and mean height are passed to
          ``dynamic_resize``.
    """
    shapes = []
    for p in (i1, i2):
        if p:
            # The context manager closes the file handle once the size is read.
            with Image.open(p) as im:
                shapes.append(im.size)  # (w, h)
    if not shapes:
        side = int(anchor_pixels ** 0.5)
        return side, side
    # For a single image the mean equals its own size.
    w = sum(s[0] for s in shapes) / len(shapes)
    h = sum(s[1] for s in shapes) / len(shapes)
    new_h, new_w = dynamic_resize(int(h), int(w), anyres, anchor_pixels=anchor_pixels)
    return new_h, new_w
def main(args):
    """Run an interactive multi-turn UniWorld-V1 chat on stdin/stdout.

    Each turn works as follows:
    1. Read an optional text prompt and optional comma-separated image
       paths/URLs, and append them to the conversation.
    2. Let ``task_head`` route the turn: either a FLUX image is generated
       (saved via ``generate_image_temp``) or a text reply is decoded.

    The loop ends when both inputs are empty or stdin is exhausted.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, task_head, processor = load_main_model_and_processor(
        args.model_path,
        device,
    )
    pipe, tokenizers, text_encoders = load_pipe(
        model.denoise_tower.denoiser, args.flux_path, device
    )
    siglip_processor, siglip_model = load_siglip_and_processor(args.siglip_path, device)

    # Conversation history
    cur_ocr_i = 0  # running index passed to get_ocr_result
    cur_genimg_i = 0  # running index for generated image file names
    history_image_paths = []  # every image seen so far (user inputs + generated)
    conversation = [
        # {"role": "system", "content": "You are a helpful assistant."},
    ]  # list of message dicts: {"role": "system"/"user"/"assistant", "content": [{...}]}
    print("Interactive UniWorld-V1 Chat (Exit if input is empty)")
    while True:
        try:
            # Prompt for optional text input
            txt = input("Text prompt (or press Enter to skip): ").strip()
            # Prompt for multiple image URLs (comma-separated)
            img_input = input("Image URLs (comma-separated, or press Enter to skip): ").strip()
        except EOFError:
            # stdin closed (Ctrl-D or piped input ran out): treat it like an empty turn.
            print("Exit.")
            break
        # Exit if no input provided
        if not img_input and not txt:
            print("Exit.")
            break

        # Parse this turn's image URLs first: the OCR enhancer below needs them.
        urls = [u.strip() for u in img_input.split(',') if u.strip()]

        # Build message content list
        content = []
        if txt:
            if args.ocr_enhancer:
                ocr_sentences = []
                for url in urls:
                    ocr_sentences.append(get_ocr_result(url, cur_ocr_i))
                    cur_ocr_i += 1
                txt = txt + '\n'.join(ocr_sentences)
            content.append({"type": "text", "text": txt})

        new_h, new_w = args.height, args.width
        if urls:
            for url in urls:
                content.append({"type": "image", "image": url, "min_pixels": 448*448, "max_pixels": 448*448})
                history_image_paths.append(url)
            if not args.no_auto_hw:
                # Derive the output size from (up to) the first two input images.
                new_h, new_w = update_size(
                    urls[0], urls[1] if len(urls) > 1 else None,
                    'any_11ratio', anchor_pixels=args.height * args.width
                )
        conversation.append({"role": "user", "content": content})
        print('conversation:\n', conversation)

        # Prepare inputs for model
        chat_text = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        chat_text = '<|im_end|>\n'.join(chat_text.split('<|im_end|>\n')[1:])  # drop system
        image_inputs, video_inputs = process_vision_info(conversation)
        inputs = processor(
            text=[chat_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Route the turn by classifying the final hidden state at the last
        # "assistant" token position.
        with torch.inference_mode():
            outputs = model(**inputs, return_dict=True, output_hidden_states=True)
            hidden_states = outputs.hidden_states[-1]  # B L D
            # NOTE(review): 77091 is assumed to be the id of the "assistant"
            # token in the Qwen2.5 vocabulary -- confirm against the tokenizer.
            assistant_mask = inputs.input_ids == 77091
            assistant_vectors = hidden_states[assistant_mask][-1:]
            task_result = task_head(assistant_vectors.float())[0]
            is_generation = bool(task_result[0] < task_result[1])
        # Release the per-layer hidden states before the memory-heavy branches.
        del outputs, hidden_states

        if is_generation:
            # gen
            with torch.no_grad():
                siglip_hidden_states = None
                if siglip_processor is not None and len(history_image_paths) > 0:
                    siglip_hidden_states = preprocess_siglip_pixel_values(
                        siglip_model, siglip_processor, history_image_paths
                    )
                lvlm_embeds = model(
                    inputs.input_ids,
                    pixel_values=getattr(inputs, 'pixel_values', None),
                    attention_mask=inputs.attention_mask,
                    image_grid_thw=getattr(inputs, 'image_grid_thw', None),
                    siglip_hidden_states=siglip_hidden_states,
                    output_type="denoise_embeds",
                )
                assert lvlm_embeds.shape[0] == 1
                input_embeds = lvlm_embeds
                t5_prompt_embeds, pooled_prompt_embeds = encode_prompt(
                    text_encoders,
                    tokenizers,
                    txt if not args.no_joint_with_t5 else '',
                    256,
                    device,
                    1,
                )
                if not args.no_joint_with_t5:
                    input_embeds = torch.concat([t5_prompt_embeds, input_embeds], dim=1)
                output_image = pipe(
                    prompt_embeds=input_embeds,
                    pooled_prompt_embeds=pooled_prompt_embeds,
                    height=new_h,
                    width=new_w,
                    num_inference_steps=args.num_inference_steps,
                    guidance_scale=args.guidance_scale,
                    # Create the generator on the active device so the CPU fallback works too.
                    generator=torch.Generator(device=device).manual_seed(seed),
                ).images[0]
            img_url = generate_image_temp.format(cur_genimg_i)
            cur_genimg_i += 1
            output_image.save(img_url)
            conversation.append({"role": "assistant", "content": [{"type": "image", "image": img_url}]})
            history_image_paths.append(img_url)
            print(f"Assistant: generate image at {img_url}\n")
        else:
            # und
            generated_ids = model.generate(**inputs, max_new_tokens=128)
            # Decode only newly generated tokens
            trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
            reply = processor.batch_decode(
                trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
            print(f"Assistant: {reply}\n")
            # Append assistant response to history
            conversation.append({"role": "assistant", "content": [{"type": "text", "text": reply}]})
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="Model and component paths")
    # Mandatory checkpoint locations.
    for path_flag in ("--model_path", "--flux_path", "--siglip_path"):
        cli.add_argument(path_flag, type=str, required=True)
    cli.add_argument("--no_auto_hw", action="store_true")
    # Output geometry and sampler step count.
    for int_flag, int_default in (
        ("--height", 1024),
        ("--width", 1024),
        ("--num_inference_steps", 28),
    ):
        cli.add_argument(int_flag, type=int, default=int_default)
    cli.add_argument("--guidance_scale", type=float, default=3.5)
    for switch in ("--ocr_enhancer", "--no_joint_with_t5"):
        cli.add_argument(switch, action="store_true")
    # Kept as the module-level name ``args``: code elsewhere in this file may
    # read it as a global.
    args = cli.parse_args()
    main(args)