Upload 18 files
modeling_vlm.py  +83 -1  CHANGED
@@ -27,12 +27,14 @@ from transformers import (
     PreTrainedModel,
     GenerationMixin
 )
+import numpy as np
 from transformers.configuration_utils import PretrainedConfig
 
 from .clip_encoder import CLIPVisionTower
 from .siglip_vit import create_siglip_vit
 from .projector import MlpProjector
 from .configuration_vlm import AttrDict, MultiModalityConfig, VisionConfig, AlignerConfig, GenVisionConfig, GenHeadConfig, GenAlignerConfig
+from .vq_model import VQ_models
 
 
 class vision_head(torch.nn.Module):
@@ -61,7 +63,7 @@ def model_name_to_cls(cls_name):
         cls = CLIPVisionTower
 
     elif "VQ" in cls_name:
-        from
+        from .vq_model import VQ_models
 
         cls = VQ_models[cls_name]
     elif "vision_head" in cls_name:
@@ -193,7 +195,87 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask, **kwargs)
         return self.language_model.generate(inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, **kwargs)
 
+    @torch.no_grad()
+    def generate_image(
+        self,
+        processor,
+        prompt: str,
+        temperature: float = 1,
+        parallel_size: int = 16,
+        cfg_weight: float = 5,
+        image_token_num_per_image: int = 576,
+        img_size: int = 384,
+        patch_size: int = 16,
+        generator=None
+    ):
+        from PIL import Image
+
+        conversation = [
+            {
+                "role": "User",
+                "content": prompt,
+            },
+            {"role": "Assistant", "content": ""},
+        ]
+
+        sft_format = processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=conversation,
+            sft_format=processor.sft_format,
+            system_prompt="",
+        )
+        prompt = sft_format + processor.image_start_tag
+        input_ids = processor.tokenizer.encode(prompt)
+        input_ids = torch.LongTensor(input_ids)
+
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = processor.pad_id
+
+        inputs_embeds = self.language_model.get_input_embeddings()(tokens)
+
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
+        past_key_values = None
+
+        for i in range(image_token_num_per_image):
+            outputs = self.language_model.model.forward(
+                input_ids=None,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=past_key_values,
+            )
+            hidden_states = outputs.last_hidden_state
+            past_key_values = outputs.past_key_values
+            logits = self.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+
+            next_token = torch.multinomial(probs, num_samples=1) if generator is None else torch.multinomial(probs, num_samples=1, generator=generator)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = self.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        dec = self.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int), [parallel_size, 8, img_size // patch_size, img_size // patch_size]
+        )
+        dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+
+        dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+        visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
+        visual_img[:, :, :] = dec
+
+        images = []
+
+        for i in range(parallel_size):
+            images.append(Image.fromarray(visual_img[i]))
 
+        return images
 
 
 AutoConfig.register("vision", VisionConfig)
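The added generate_image method duplicates the encoded prompt into conditional/unconditional pairs (every odd row keeps only its first and last token, with the rest set to processor.pad_id), autoregressively samples image_token_num_per_image tokens while applying classifier-free guidance to the gen_head logits (logit_uncond + cfg_weight * (logit_cond - logit_uncond)), and decodes the resulting code grid through gen_vision_model.decode_code into parallel_size RGB images. With the defaults, img_size // patch_size = 384 // 16 = 24, so each image uses 24 * 24 = 576 tokens, decoded from a [parallel_size, 8, 24, 24] code tensor.

A minimal usage sketch follows; it is not part of this commit. It assumes the repository's auto_map exposes this model class via AutoModelForCausalLM and ships a processor providing the attributes the method touches (apply_sft_template_for_multi_turn_prompts, sft_format, image_start_tag, pad_id, tokenizer); the repo id and the save loop are illustrative.

# Usage sketch (assumptions: repo id, AutoProcessor/AutoModelForCausalLM routing via auto_map).
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "your-org/your-vlm-checkpoint"  # hypothetical repo id
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model.eval()

# generate_image returns a list of parallel_size PIL images.
images = model.generate_image(
    processor=processor,
    prompt="a watercolor painting of a lighthouse at dusk",
    parallel_size=4,                             # number of samples drawn in parallel
    cfg_weight=5.0,                              # classifier-free guidance scale
    temperature=1.0,
    generator=torch.Generator().manual_seed(0),  # optional, for reproducible sampling
)
for idx, img in enumerate(images):
    img.save(f"sample_{idx}.png")

Since the method allocates its token buffers with plain torch.zeros(...), the sketch leaves the model on the default device; running it on GPU would require the model and those buffers to end up on the same device.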