File size: 69,348 Bytes
627f346 5644dea 627f346 5644dea 627f346 5644dea 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 e3bf9ba 627f346 a5ce455 627f346 e3bf9ba 627f346 a5ce455 627f346 a5ce455 627f346 a5ce455 627f346 5644dea 627f346 5644dea 627f346 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 |
# Copyright 2024 Hao Zhang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union, Dict
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import transformers
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
# from .llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from .modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
# from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
import pdb
import time
import random
random.seed(42)
import torch
from statistics import mean
import torch.nn.functional as F
import PIL
from decord import VideoReader, cpu
from .conversation import conv_templates, SeparatorStyle
from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_TOKEN
from .mm_utils import tokenizer_image_token, load_video, KeywordsStoppingCriteria, get_anyres_image_grid_shape
import math
import re
from .vision_tower_builder import build_vision_tower
from .vision_resampler_builder import build_vision_resampler
from .vision_projector_builder import build_vision_projector
from .utils import rank0_print
from .sae import SiglipAE
import numpy as np
import pdb
from abc import ABC, abstractmethod
class LlavaMetaModel:
    """Mixin that attaches vision modules to a language-model backbone.

    It forwards ``config`` to the next class in the MRO and may create the
    attributes ``vision_tower``, ``vision_resampler``, ``mm_projector``,
    ``image_newline``, ``hidden_size``, ``text_mlp`` and ``sae``
    (a ``SiglipAE``).
    """

    # Fallback location of the pretrained SiglipAE weights loaded by
    # ``initialize_vision_modules``; a truthy ``model_args.pretrain_sae``
    # attribute overrides it.
    _DEFAULT_SAE_WEIGHTS = "/share/LXRlxr0_0/code/videoxl2/videoxl2/longva/longva/model/encoder.pth"

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                # Allocated uninitialised (torch.empty); presumably filled by checkpoint loading.
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

        self.hidden_size = config.hidden_size
        self.text_mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
        )
        self.sae = SiglipAE()

    def get_vision_tower(self):
        """Return the vision tower, or ``None`` if none has been attached.

        When the tower is stored as a list (the FSDP path of
        ``initialize_vision_modules``), its first element is returned.
        """
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    @staticmethod
    def _select_prefixed_weights(weights, keyword):
        """Return the entries of ``weights`` whose key contains ``keyword + "."``.

        Each key is replaced by the text following the first occurrence of
        ``keyword + "."``. Keys that merely contain ``keyword`` without the dot
        are skipped instead of raising ``IndexError``.
        """
        prefix = keyword + "."
        return {k.split(prefix)[1]: v for k, v in weights.items() if prefix in k}

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build (or reload) the vision tower, resampler, SAE and projector.

        Args:
            model_args: training arguments providing ``vision_tower``,
                ``mm_vision_select_layer``, ``mm_vision_select_feature``,
                ``pretrain_mm_mlp_adapter`` and ``mm_patch_merge_type``;
                optionally ``vision_tower_pretrained``, ``mm_projector_type``
                and ``pretrain_sae``.
            fsdp: when non-empty, newly built tower/resampler are stored in
                one-element lists.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type
        use_fsdp = fsdp is not None and len(fsdp) > 0

        self.config.mm_vision_tower = vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if use_fsdp:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            if use_fsdp:
                vision_resampler = self.vision_resampler[0]
                vision_tower = self.vision_tower[0]
            else:
                vision_resampler = self.vision_resampler
                vision_tower = self.vision_tower
            vision_tower.load_model()

            # In case it is frozen by LoRA. Use the unwrapped local: under FSDP
            # self.vision_resampler is a list, which has no .parameters().
            for p in vision_resampler.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
        self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        # Fresh SAE, then (non-strict) load of pretrained weights. Loading on CPU
        # is enough: load_state_dict copies into the module's own parameters.
        sae_weights = getattr(model_args, "pretrain_sae", None) or self._DEFAULT_SAE_WEIGHTS
        self.sae = SiglipAE()
        self.sae.load_state_dict(torch.load(sae_weights, map_location="cpu"), strict=False)

        if getattr(self, "mm_projector", None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)

            if "unpad" in mm_patch_merge_type:
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")

            incompatible_keys = self.mm_projector.load_state_dict(self._select_prefixed_weights(mm_projector_weights, "mm_projector"))
            rank0_print(f"Loaded mm projector weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
            incompatible_keys = vision_resampler.load_state_dict(self._select_prefixed_weights(mm_projector_weights, "vision_resampler"), strict=False)
            rank0_print(f"Loaded vision resampler weights from {pretrain_mm_mlp_adapter}. Incompatible keys: {incompatible_keys}")
def unpad_image(tensor, original_size):
    """Remove the letterbox padding added when an image was resized into a grid.

    Args:
        tensor (torch.Tensor): padded image/feature tensor whose last two
            dimensions are (height, width), e.g. CxHxW.
        original_size (tuple): original image size as (width, height)
            (PIL ``Image.size`` order).

    Returns:
        torch.Tensor: view of ``tensor`` with the symmetric padding on either the
        height or the width axis sliced away; leading dimensions are untouched.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[-2:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        # Original is wider than the canvas: padding was added to the height.
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        return tensor[..., padding : current_height - padding, :]

    # Original is taller (or same ratio): padding was added to the width.
    scale_factor = current_height / original_height
    new_width = int(original_width * scale_factor)
    padding = (current_width - new_width) // 2
    return tensor[..., :, padding : current_width - padding]
class LlavaMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def get_2dPool(self, image_feature):
height = width = self.get_vision_tower().num_patches_per_side
num_frames, num_tokens, num_dim = image_feature.shape
image_feature = image_feature.view(num_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
# image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
if self.config.mm_spatial_pool_mode == "average":
image_feature = nn.functional.avg_pool2d(image_feature, self.config.mm_spatial_pool_stride)
elif self.config.mm_spatial_pool_mode == "max":
image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
else:
raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
image_feature = image_feature.permute(0, 2, 3, 1)
image_feature = image_feature.view(num_frames, -1, num_dim)
return image_feature
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
#image_features = self.get_model().vision_resampler(image_features, images=images)
image_features = self.get_model().mm_projector(image_features)
image_features = self.get_model().vision_resampler(image_features, images=images)
return image_features
def add_image(self, image_features):
return torch.repeat_interleave(image_features, repeats=4, dim=0)
def add_video(self, video_features):
# Current batch size
current_batch_size = video_features.size(0)
# Handle cases where the batch size is less than 4
if current_batch_size < 4:
last_feature = video_features[-1:]
# Calculate how many times the last feature needs to be repeated
num_repeats = 4 - current_batch_size
repeated_features = last_feature.repeat(num_repeats, 1, 1, 1)
# Concatenate original features with repeated last feature
expanded_x = torch.cat([video_features, repeated_features], dim=0)
return expanded_x
# Handle cases where the batch size is 4 or greater, but not a multiple of 4
if current_batch_size % 4 != 0:
last_feature = video_features[-1:]
# Calculate how many features are needed to reach the next multiple of 4
padding_size = 4 - (current_batch_size % 4)
repeated_features = last_feature.repeat(padding_size, 1, 1, 1)
# Concatenate original features with repeated last feature
expanded_x = torch.cat([video_features, repeated_features], dim=0)
return expanded_x
# If the batch size is already a multiple of 4, return as is
return video_features
def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
if self.config.enable_chunk_prefill:
chunk_size_for_vision_tower = self.config.prefill_config['chunk_size_for_vision_tower']
else:
chunk_size_for_vision_tower = 100000
# pdb.set_trace()
# Define the maximum batch size (1024 frames)
max_batch_size = chunk_size_for_vision_tower
# print(f'max_batch_size: {max_batch_size}')
num_frames = videos_or_images.shape[0]
# Initialize a list to store the features from each batch
videos_or_images_features = []
videos_or_images_features = torch.empty((num_frames, 729, 1152), device=self.get_model().device, dtype=self.get_model().dtype)
# Split videos_or_images into smaller batches if num_frames > max_batch_size
current_idx = 0
if num_frames > max_batch_size:
# Calculate the number of batches needed
num_batches = (num_frames + max_batch_size - 1) // max_batch_size
for i in range(num_batches):
start_idx = i * max_batch_size
end_idx = min((i + 1) * max_batch_size, num_frames)
# Process each batch separately
batch_videos_or_images = videos_or_images[start_idx:end_idx]
batch_features = self.get_model().get_vision_tower()(batch_videos_or_images)
# videos_or_images_features.append(batch_features)
videos_or_images_features[current_idx:current_idx + batch_features.shape[0]] = batch_features
# Update the current index for the next batch
current_idx += batch_features.shape[0]
# peak_memory_allocated = torch.cuda.max_memory_allocated()
# print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
# Concatenate the features of all batches
# videos_or_images_features = torch.cat(videos_or_images_features, dim=0)
else:
videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
all_videos_or_images_features = []
# peak_memory_allocated = torch.cuda.max_memory_allocated()
# print(f"vision encoder 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
del videos_or_images_features
torch.cuda.empty_cache()
chunk_size = chunk_size_for_vision_tower
# print(f'chunk_size: {chunk_size}')
all_feat_list = []
for idx, feat in enumerate(per_videos_or_images_features):
for i in range(0, feat.shape[0], chunk_size):
batched_feat = feat[i:i+chunk_size] # chunk_size = 48, batched_feat.shape=[48, 729, 1152]
batched_feat=self.interpolate(batched_feat) # 插值后 batched_feat.shape=[48, 1152, 24, 24]
if idx in video_idx_in_batch:
batched_feat = self.add_video(batched_feat) # 第一纬度补充到4的倍数
else:
batched_feat = self.add_image(batched_feat)
bc,ch,h,w = batched_feat.shape
batched_feat = batched_feat.view(bc//4,ch,4,h,w)
batched_feat = self.get_model().sae(batched_feat).squeeze(2)
batched_feat = batched_feat.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
batched_feat = self.get_model().mm_projector(batched_feat)
batched_feat = self.get_2dPool(batched_feat)
all_feat_list.append(batched_feat)
feat = torch.cat(all_feat_list, dim=0)
# peak_memory_allocated = torch.cuda.max_memory_allocated()
# print(f"sae 显存峰值: {peak_memory_allocated / (1024**3):.2f} GB") # 转换为GB
del per_videos_or_images_features
del all_feat_list
torch.cuda.empty_cache()
all_videos_or_images_features.append(feat)
return all_videos_or_images_features
def interpolate(self,image_features):
b, num_tokens, dim = image_features.shape
#print(str(image_features.shape)+' i\n')
target_h = target_w = int(576**0.5)
h = w = int(num_tokens**0.5)
image_features = image_features.view(b, h, w, dim)
image_features = image_features.permute(0, 3, 1, 2).contiguous()
chunk_size = 24
chunks = torch.split(image_features, chunk_size, dim=0)
interpolated_chunks = []
for chunk in chunks:
interpolated_chunk = F.interpolate(
chunk.to(torch.float32),
size=(target_h, target_w),
mode="bilinear",
align_corners=False,
).to(chunk.dtype)
interpolated_chunks.append(interpolated_chunk)
image_features = torch.cat(interpolated_chunks, dim=0)
del interpolated_chunks
del chunks
return image_features
def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None,time_embedding=None):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
video_idx_in_batch = []
for _ in range(len(modalities)):
if modalities[_] == "video":
video_idx_in_batch.append(_)
images_list = []
for image in images:
if image.ndim == 4:
images_list.append(image)
else:
images_list.append(image.unsqueeze(0))
#print(len(images_list),images_list[0].shape)
concat_images = torch.cat([image for image in images_list], dim=0)
split_sizes = [image.shape[0] for image in images_list]
image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes) #16,144,3584
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
visual_drop_score=[]
new_image_features=[]
if mm_patch_merge_type == "flat":
if image_features[0].ndim>2:
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type== "unires":
#print('unires')
for image_idx, image_feature in enumerate(image_features):
# rank0_print(f"Initial feature size : {image_feature.shape}")
if image_idx in video_idx_in_batch: # video operations
#print(image_feature.shape)
image_feature = image_feature.flatten(0, 1)
elif image_feature.shape[0] > 1:
# base image feature is never used in unires
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]
kernel_size = mm_patch_merge_type.split("avgpool")[-1].split("x")[-1]
kernel_size = 2
image_feature = image_feature.view(image_feature.shape[0], height, width, -1) # [4, 24, 24, 4096]
image_feature = image_feature.permute(0, 3, 1, 2).contiguous() # [4, 4096, 24, 24]
image_feature = nn.functional.avg_pool2d(image_feature,kernel_size) # [4, 4096, 12, 12]
image_feature = image_feature.flatten(2, 3) # [4, 4096, 144]
image_feature = image_feature.permute(0, 2, 1).contiguous() # [4, 144, 4096]
#print(image_feature.shape)
image_feature = image_feature.flatten(0, 1)
else:
image_feature = image_feature[0]
new_image_features.append(image_feature)
image_features = new_image_features
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
# FIXME: now assume the image is square, and split to 2x2 patches
# num_patches = h * w, where h = w = sqrt(num_patches)
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
# we want to first unflatten it to (2, 2, h, w, hidden_size)
if image_idx in video_idx_in_batch: # video operations
if "unpad" in mm_patch_merge_type:
# image_feature = image_feature.permute(2, 0, 1).contiguous()
# image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
# image_feature = image_feature.permute(1, 2, 0).contiguous()
image_feature = image_feature.flatten(0, 1)
image_feature = torch.cat((image_feature, self.model.image_newline[None].to(image_feature.device)), dim=0)
elif image_feature.shape[0] > 1: # multi patches and multi images operations
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
if hasattr(self.get_vision_tower(), "image_size"):
vision_tower_image_size = self.get_vision_tower().image_size
else:
raise ValueError("vision_tower_image_size is not found in the vision tower.")
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
if "maxpool2x2" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = nn.functional.max_pool2d(image_feature, 2)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
if "nobase" in mm_patch_merge_type:
pass
else:
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else: # single image operations
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
error_message = """
Something is wrong with the input shape. Most likely, you did not wrap the image or video input in a list:
This is correct:
model.generate(input_ids, images=[video_tensor], modalities=["video"], **gen_kwargs)
model.generate(input_ids, images=[image_tensor], modalities=["image"], **gen_kwargs)
This is wrong:
model.generate(input_ids, images=video_tensor, modalities=["video"], **gen_kwargs)
model.generate(input_ids, images=image_tensor, modalities=["image"], **gen_kwargs)
"""
raise ValueError(error_message)
#print(time_embedding[0].shape)
#video_token_indices=[]
for image_idx, image_feature in enumerate(image_features):
if time_embedding[image_idx] is not None:
mask = (time_embedding[image_idx] == 151654)
indices = torch.nonzero(mask).squeeze()
embed_token=self.get_model().embed_tokens(time_embedding[image_idx])
embed_token[indices]=image_features[image_idx]
#video_token_indices.append(indices)
image_features[image_idx]=embed_token
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
#print(num_images)
if num_images>=2:
print(num_images,input_ids)
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
#print(image_token_indices) #[-1, 14, 236]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
# print(cur_input_ids)
# print(labels[batch_idx])
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
#print(torch.cat(cur_input_ids_noim).shape,torch.cat(cur_input_ids_noim))
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
##############
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
# import pdb; pdb.set_trace()
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
# NOTE: qmh
# new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
# new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
# TODO: Hard code for control loss spike
# if tokenizer_model_max_length is not None:
# new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
# new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
if getattr(self.config, "use_pos_skipping", False) and self.training:
position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
split_position = random.randint(0, new_input_embeds.size(1))
left_add = random.randint(0, self.config.pos_skipping_range)
right_add = random.randint(left_add, self.config.pos_skipping_range)
position_ids[:, :split_position] += left_add
position_ids[:, split_position:] += right_add
# import pdb; pdb.set_trace()
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def initialize_vision_tokenizer(self, model_args, tokenizer):
# Register the multimodal special tokens on `tokenizer`, resize the model's token
# embeddings to match, and optionally initialise / (un)freeze the embedding tables.
# NOTE(review): indentation was lost in this copy of the file; the nesting implied by
#   these comments is inferred (it matches the upstream LLaVA layout) - confirm before editing.
# Patch token: added and embeddings resized; its new row is not explicitly initialised here.
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
# Image start/end tokens.
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
# Set every newly appended row (input and output embeddings) to the mean of the
# rows that existed before the resize.
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
# When tuning the projector: input embeddings trainable, output embeddings frozen.
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
# Overwrite the start/end rows with the ones stored in a pretrained projector checkpoint.
if model_args.pretrain_mm_mlp_adapter:
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
#   (newer PyTorch versions accept weights_only=True).
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location="cpu")
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
# NOTE(review): `assert` is stripped under `python -O`; if num_new_tokens were 0 here,
#   `input_embeddings` below would be unbound.
assert num_new_tokens == 2
# Accept either a full embedding matrix (take its last rows) or exactly the new rows.
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
# Patch token only (no start/end tokens): freeze both tables when tuning the projector.
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
class LlavaQwenConfig(Qwen2Config):
    """Qwen2 configuration tagged with the ``llava_qwen`` model type.

    The ``model_type`` string is the key passed to ``AutoConfig.register`` at
    module level.
    """

    model_type = "llava_qwen"
class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
    """Qwen2 decoder combined with the LLaVA multimodal mixin."""

    config_class = LlavaQwenConfig

    def __init__(self, config: Qwen2Config):
        # Cooperative init: the next class in the MRO after LlavaQwenModel is LlavaMetaModel.
        super().__init__(config)
class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM head over ``LlavaQwenModel`` with video-aware prefill paths."""

    config_class = LlavaQwenConfig

    def __init__(self, config):
        # Call the Qwen2 causal-LM constructor directly rather than skipping it via
        # super(Qwen2ForCausalLM, self); whatever decoder/head it builds (presumably its
        # own Qwen2Model and lm_head) is replaced just below.
        Qwen2ForCausalLM.__init__(self, config)
        # These config edits happen after the parent constructor ran, so only the
        # replacement decoder built below observes them.
        config.rope_scaling = None
        config.model_type = "llava_qwen"
        self.model = LlavaQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
def get_model(self):
    """Return the multimodal decoder stored on ``self.model``."""
    decoder = self.model
    return decoder
def uniform_sampling(self, embeds, start_idx, end_idx, step):
    """Keep every ``step``-th position of ``embeds`` along dim 1 in ``[start_idx, end_idx)``.

    Returns:
        ``(sampled, positions)``: the selected slices and the 1-D index tensor
        (on ``embeds.device``) used to pick them.
    """
    positions = torch.arange(start_idx, end_idx, step, device=embeds.device)
    sampled = torch.index_select(embeds, 1, positions)
    return sampled, positions
def pooling_sampling(self, embeds, start_idx, end_idx, step, pool_type='avg'):
    """Downsample the token span ``embeds[:, start_idx:end_idx, :]`` by 1-D pooling.

    Non-overlapping windows of ``step`` tokens (kernel = stride = ``step``) are
    reduced along the sequence axis; a trailing partial window is dropped.

    Args:
        embeds: 3-D tensor indexed as ``(batch, seq, dim)``.
        start_idx, end_idx: token span to pool (end exclusive).
        step: pooling window and stride.
        pool_type: ``'avg'``/``'avg_pool'`` for average pooling,
            ``'max'``/``'max_pool'`` for max pooling. The default ``'avg'`` is
            accepted (it previously matched no branch and always raised).

    Returns:
        ``(pooled, positions)``: ``pooled`` is ``(batch, n_windows, dim)`` and
        ``positions`` holds the starting token index of each window.

    Raises:
        ValueError: for an unsupported ``pool_type``.
    """
    pool_fns = {
        'avg': F.avg_pool1d,
        'avg_pool': F.avg_pool1d,
        'max': F.max_pool1d,
        'max_pool': F.max_pool1d,
    }
    pool_fn = pool_fns.get(pool_type)
    if pool_fn is None:
        raise ValueError(f"Unsupported pooling type: {pool_type}")
    selected = embeds[:, start_idx:end_idx, :]
    # 1-D pooling runs over the last axis, so move the sequence axis there and back.
    pooled = pool_fn(selected.transpose(1, 2), kernel_size=step, stride=step).transpose(1, 2)
    positions = torch.arange(start_idx, start_idx + pooled.shape[1] * step, step, device=embeds.device)
    return pooled, positions
def process_block(self, block_embeds, current_past_key_values=None, bsz=1, device=None, position_ids=None, key_position_ids=None):
    """Run one chunk of embeddings through ``self.model`` and return its KV cache.

    Without a cache, positions restart at 0 (any ``position_ids`` argument is
    ignored) and the all-ones attention mask covers just the chunk. With a cache,
    the caller's ``position_ids`` are used as-is and the mask covers the cached
    length (``current_past_key_values[0][0].size(2)``) plus the chunk.

    Returns:
        ``past_key_values`` of the model output (``use_cache=True``).
    """
    chunk_len = block_embeds.size(1)
    if current_past_key_values is not None:
        cached_len = current_past_key_values[0][0].size(2)
        mask_len = cached_len + chunk_len
    else:
        position_ids = torch.arange(0, chunk_len, device=device).expand(bsz, -1)
        mask_len = chunk_len
    attention_mask = torch.ones((bsz, mask_len), device=device, dtype=torch.long)
    result = self.model(
        inputs_embeds=block_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        key_position_ids=key_position_ids,
        past_key_values=current_past_key_values,
        use_cache=True,
        return_dict=True,
    )
    return result.past_key_values
def pooling_kvs(self, kvs, step):
    """Average-pool a 4-D key/value tensor along its sequence axis (dim 2).

    ``kvs`` is indexed as ``(batch, heads, seq, head_dim)``. Non-overlapping
    windows of ``step`` positions are averaged (kernel = stride = ``step``), so the
    result has shape ``(batch, heads, seq // step, head_dim)``.
    """
    batch, heads, seq, head_dim = kvs.shape
    # avg_pool1d pools the last axis of an (N, C, L) tensor: fold batch*heads into N
    # and put the sequence axis last.
    folded = kvs.transpose(2, 3).reshape(batch * heads, head_dim, seq)
    pooled = F.avg_pool1d(folded, kernel_size=step, stride=step)
    return pooled.view(batch, heads, pooled.shape[1], pooled.shape[2]).transpose(2, 3)
def get_sparse_attention_mask(self, total_len, num_blocks, block_size, time_token_start_indices, time_token_end_indices, time_token_indices, visual_token_start_pos, visual_token_end_pos, attention_mask, inputs_embeds, prev_blocks_num=None):
    """Build a block-sparse additive attention mask for a video prompt.

    The visual span ``[visual_token_start_pos, visual_token_end_pos)`` is cut into
    ``num_blocks`` consecutive blocks. Block ``i`` ends where frame group
    ``(i + 1) * block_size`` starts (``time_token_start_indices``); once that index
    runs past the list, the block ends at ``visual_token_end_pos``.

    Queries inside a block may attend to:
    * their own block;
    * the ``prev_blocks_num`` preceding blocks (or back to the visual start when
      fewer exist).

    In addition:
    * every query may attend to the text prefix;
    * queries at or after ``visual_token_end_pos`` are unrestricted;
    * every row and column listed in ``time_token_indices`` is fully open.

    Finally the pattern is intersected with a causal (lower-triangular) mask.

    Args:
        total_len: full prompt length.
        num_blocks: number of visual blocks to lay out.
        block_size: frame groups per block.
        time_token_start_indices: absolute start position of each frame group.
        time_token_end_indices: unused; kept for signature compatibility.
        time_token_indices: absolute positions given global attention.
        visual_token_start_pos, visual_token_end_pos: visual span (end exclusive).
        attention_mask: only its device is used; ``None`` falls back to
            ``inputs_embeds.device``.
        inputs_embeds: supplies the output dtype.
        prev_blocks_num: number of preceding blocks each block may see; ``0`` (or
            negative) means none.

    Returns:
        ``(mask, ratio)``.
        * ``mask`` has shape ``(1, 1, total_len, total_len)`` and dtype
          ``inputs_embeds.dtype``. It holds ``0`` where attention is allowed and
          ``-1e9`` elsewhere (this rounds to ``-inf`` in float16).
        * ``ratio`` is the fraction of causal entries that are kept.

    Raises:
        ValueError: if ``prev_blocks_num`` is ``None``.
    """
    if prev_blocks_num is None:
        raise ValueError("prev_blocks_num must be an int (number of previous blocks to attend to)")
    allowed = torch.zeros(total_len, total_len, dtype=torch.bool)
    block_start = visual_token_start_pos
    block_starts = []
    for i in range(num_blocks):
        next_group = (i + 1) * block_size
        if next_group >= len(time_token_start_indices):
            block_end = visual_token_end_pos
        else:
            block_end = time_token_start_indices[next_group]
        allowed[block_start:block_end, block_start:block_end] = True
        # Guarding on > 0 fixes prev_blocks_num == 0: `block_starts[-0]` would index the
        # first element (IndexError on the first block) instead of meaning "no history".
        if prev_blocks_num > 0:
            if len(block_starts) >= prev_blocks_num:
                window_start = block_starts[-prev_blocks_num]
            else:
                window_start = visual_token_start_pos
            allowed[block_start:block_end, window_start:block_start] = True
        block_starts.append(block_start)
        block_start = block_end
    allowed[:, :visual_token_start_pos] = True
    allowed[visual_token_end_pos:, :] = True
    global_idx = torch.as_tensor(time_token_indices, dtype=torch.long)
    allowed[global_idx, :] = True
    allowed[:, global_idx] = True
    allowed &= torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    num_allowed = allowed.sum().item()
    causal_total = total_len * (total_len + 1) // 2
    ratio = num_allowed / causal_total
    device = attention_mask.device if attention_mask is not None else inputs_embeds.device
    # Build the additive mask in float directly; the old `1.0 - mask` path failed for
    # bool attention masks.
    blocked = (~allowed).to(device=device, dtype=torch.float32)
    final_mask = (blocked * -1e9).to(dtype=inputs_embeds.dtype).unsqueeze(0).unsqueeze(0)
    return final_mask, ratio
def cat_history_kvs(self, prefix_kvs, kvs_part2, kvs_part3):
    """Concatenate three KV histories layer by layer along the sequence axis (dim -2).

    Args:
        prefix_kvs: per layer, a single ``(key, value)`` pair.
        kvs_part2, kvs_part3: per layer, a list of ``(key, value)`` pairs (may be empty).

    Returns:
        List with one ``(key, value)`` tuple per layer, ordered prefix, part2, part3.
        Layers beyond the shortest input are dropped (``zip`` semantics).
    """
    merged = []
    for prefix_pair, part2_pairs, part3_pairs in zip(prefix_kvs, kvs_part2, kvs_part3):
        segments = [prefix_pair, *part2_pairs, *part3_pairs]
        layer_keys = torch.cat([seg[0] for seg in segments], dim=-2)
        layer_vals = torch.cat([seg[1] for seg in segments], dim=-2)
        merged.append((layer_keys, layer_vals))
    return merged
def forward_streaming(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
key_position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
dpo_forward: Optional[bool] = False,
cache_position=None,
visual_token_start_pos=None,
visual_token_end_pos=None,
time_token_start_indices=None,
frames_num=None,
time_token_indices=None,
time_token_end_indices=None,
block_size_chosed=None,
prev_blocks_num=None,
offload: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
# Streaming (chunked) prefill of a long video prompt, then a normal forward over the text suffix.
#
# Stages, as implemented below:
#   1. Split the visual span [visual_token_start_pos, visual_token_end_pos) into frame groups.
#      Group k is (time_token_start_indices[k], time_token_end_indices[k], start of group k+1).
#      Pack `block_size_chosed` groups per block; blocks_positions[0] stands for the text prefix.
#   2. Encode the prefix once, then encode the first `prev_blocks_num` visual blocks together on top of it.
#   3. Encode each later block against a concatenated cache built from three parts:
#        part1 - prefix KV;
#        part2 - time-label KV harvested from blocks that have left the window;
#        part3 - full KV of the last `prev_blocks_num` blocks.
#      Only the new block's own KV is stored.
#   4. Concatenate every stored segment into one cache and run super().forward on the suffix with it.
# `key_position_ids` is passed alongside the cache so each key keeps its absolute position even
# though the concatenated history skips tokens (assumes the patched Qwen2 attention honours this
# argument - TODO confirm).
# Returns whatever super().forward returns for the suffix (CausalLMOutputWithPast or tuple).
#
# NOTE(review): indentation was lost in this copy; the loop/branch nesting below is inferred.
# NOTE(review): several arguments are accepted but never read (or are overwritten before use):
#   input_ids, attention_mask, position_ids, key_position_ids, past_key_values, labels,
#   use_cache, dpo_forward, cache_position, time_token_indices.
# NOTE(review): failure modes visible in the code:
#   - prev_blocks_num == 0, or no visual blocks, makes `prev_blocks[0]` raise IndexError;
#   - an empty prefix makes `prefix_past_key_values[0]` raise IndexError;
#   - an empty suffix either leaves `outputs` unbound at the return or returns None (depends on nesting).
block_size = block_size_chosed
visual_token_start_pos = visual_token_start_pos
visual_token_end_pos = visual_token_end_pos
# visual_len is unused; num_blocks is recomputed as len(blocks_positions) further down.
visual_len = visual_token_end_pos - visual_token_start_pos
num_blocks = (frames_num + block_size * 4 - 1) // (block_size * 4)
# Build the streaming layout: frame groups, then blocks of `block_size` groups.
# blocks_positions[0] is a placeholder for the text prefix; later entries are lists of
# (time_start, time_end, group_end) tuples in absolute token positions.
blocks_positions = [[(0, 0, visual_token_start_pos)]]
frames_groups = [(0, visual_token_start_pos)]
for idx, (time_start, time_end) in enumerate(zip(time_token_start_indices, time_token_end_indices)):
if idx + 1 < len(time_token_start_indices):
frames_group_end = time_token_start_indices[idx + 1]
else:
frames_group_end = visual_token_end_pos
frames_groups.append(
(time_start, time_end, frames_group_end)
)
# frames_groups[0] is a 2-tuple placeholder, so packing starts at index 1; a trailing
# partial block is kept.
single_block = []
for group in frames_groups[1:]:
single_block.append(group)
if len(single_block) == block_size:
blocks_positions.append(single_block)
single_block = []
if len(single_block) != 0:
blocks_positions.append(single_block)
num_blocks = len(blocks_positions)
# `start` and `record_prefill_time` are timing leftovers; they are never reported.
start = time.time()
record_prefill_time = 0
full_inputs_embeds = inputs_embeds
bsz, total_len, embed_dim = full_inputs_embeds.size()
device = full_inputs_embeds.device
prefix_embeds = full_inputs_embeds[:, :visual_token_start_pos, :]
visual_embeds = full_inputs_embeds[:, visual_token_start_pos:visual_token_end_pos, :]
suffix_embeds = full_inputs_embeds[:, visual_token_end_pos:, :]
num_visual_tokens = visual_embeds.size(1)
# Per layer: the list of (key, value) segments kept for the final cache
# (prefix first, then one segment per visual block, in order).
all_past_key_values = [[] for _ in range(len(self.model.layers))]
prefix_past_key_values = []
# torch.cuda.reset_peak_memory_stats()
# Stage 2a: encode the text prefix from scratch.
if prefix_embeds.size(1) > 0:
pkv = self.process_block(prefix_embeds, bsz=bsz, device=device)
for i in range(len(pkv)):
all_past_key_values[i].append(pkv[i])
prefix_past_key_values.append(pkv[i])
# Stage 2b: encode the first `prev_blocks_num` visual blocks in one pass on top of the prefix.
prev_blocks = blocks_positions[1:1+prev_blocks_num]
prev_the_first_block = prev_blocks[0]
prev_b_start = prev_the_first_block[0][0]
prev_the_last_block = prev_blocks[-1]
prev_b_end = prev_the_last_block[-1][-1]
block_streaming_past_key_values = prefix_past_key_values
query_position_ids = torch.arange(prev_b_start, prev_b_end, dtype=torch.long, device=device)
past_key_position_ids = torch.arange(0, block_streaming_past_key_values[0][0].size(2), dtype=torch.long, device=device)
key_position_ids = torch.cat([past_key_position_ids, query_position_ids], dim=0)
visual_embeds_this_block = full_inputs_embeds[:,prev_b_start:prev_b_end,:]
pkv = self.process_block(visual_embeds_this_block, current_past_key_values=block_streaming_past_key_values, bsz=bsz, device=device, position_ids=query_position_ids.unsqueeze(0), key_position_ids=key_position_ids.unsqueeze(0))
# The returned cache holds prefix + these blocks contiguously from position 0, so absolute
# block boundaries index it directly (assumes the first group starts at visual_token_start_pos).
for i in range(len(pkv)):
for block in prev_blocks:
block_start, _, _ = block[0]
_, _, block_end = block[-1]
all_past_key_values[i].append( (pkv[i][0][:,:,block_start:block_end], pkv[i][1][:,:,block_start:block_end]) )
# Stage 3 state. part1: prefix KV (fixed). part2: per layer, time-label KV of blocks that
# slid out of the window, with positions in position_ids_part2. part3: sliding window,
# rebuilt on every iteration.
block_streaming_past_key_values_part1 = prefix_past_key_values
position_ids_part1 = torch.arange(0, prefix_past_key_values[0][0].size(2), dtype=torch.long, device=device)
block_streaming_past_key_values_part2 = [[] for _ in range(len(self.model.layers))]
position_ids_part2 = torch.tensor([], dtype=torch.long, device=device)
block_streaming_past_key_values_part3=None
position_ids_part3 = None
query_position_ids = None
for idx, single_block in enumerate(blocks_positions[:]):
# Skip the prefix placeholder and the blocks already encoded in stage 2b.
if idx == 0 or idx <= prev_blocks_num:
continue
b_start, _, _ = single_block[0]
_, _, b_end = single_block[-1]
visual_embeds_this_block = full_inputs_embeds[:,b_start:b_end,:]
prev_blocks = blocks_positions[max(idx - prev_blocks_num, 1):idx]
prev_the_first_block = prev_blocks[0]
prev_b_start = prev_the_first_block[0][0]
# The three *_length values below are unused.
this_block_length = b_end - prev_b_start
prev_block_length = b_start - prev_b_start
true_block_length = b_end - b_start
# Window = last `prev_blocks_num` stored segments per layer (new lists, so later appends
# to all_past_key_values do not change them).
block_streaming_past_key_values_part3 = [tmp[-prev_blocks_num:] for tmp in all_past_key_values]
# With offload, stored segments live on the CPU; bring the window back to `device`.
if offload:
block_streaming_past_key_values_part3 = [
[
(t[0].to(device=device), t[1].to(device=device))
for t in sublist
]
for sublist in block_streaming_past_key_values_part3
]
block_streaming_past_key_values = self.cat_history_kvs(block_streaming_past_key_values_part1, block_streaming_past_key_values_part2, block_streaming_past_key_values_part3)
# Key positions mirror the cache layout: prefix, harvested time labels, window, new block.
query_position_ids = torch.arange(b_start, b_end, dtype=torch.long, device=device)
position_ids_part3 = torch.arange(prev_b_start, b_start, dtype=torch.long, device=device)
key_position_ids = torch.cat([position_ids_part1, position_ids_part2, position_ids_part3, query_position_ids], dim=0)
start_1 = time.time()
pkv = self.process_block(visual_embeds_this_block, current_past_key_values=block_streaming_past_key_values, bsz=bsz, device=device, position_ids=query_position_ids.unsqueeze(0), key_position_ids=key_position_ids.unsqueeze(0))
end_1 = time.time()
record_prefill_time += end_1-start_1
for i in range(len(pkv)):
# The returned cache = history it was conditioned on + this block; keep only the block.
length_before_chunk = block_streaming_past_key_values[i][0].size(2)
key_this_block, val_this_block = pkv[i]
key_this_block = key_this_block[:,:,length_before_chunk:,:]
val_this_block = val_this_block[:,:,length_before_chunk:,:]
if offload:
all_past_key_values[i].append( (key_this_block.to('cpu'), val_this_block.to('cpu')) )
else:
all_past_key_values[i].append( (key_this_block, val_this_block) )
# Harvest the time-label tokens of the oldest block in the window (part3[i][0]).
# That block drops out of part3 on the next iteration, but its labels stay reachable via part2.
time_keys_list = []
time_vals_list = []
extract_timestamps_position_ids_list = []
for group in prev_the_first_block:
time_start, time_end, _ = group
extract_timestamps_position_ids_list.append(torch.arange(time_start, time_end, dtype=torch.long, device=device))
# Convert absolute positions to offsets inside that block's KV segment.
time_start = time_start - prev_b_start
time_end = time_end - prev_b_start
time_keys_list.append(block_streaming_past_key_values_part3[i][0][0][:,:,time_start:time_end,:])
time_vals_list.append(block_streaming_past_key_values_part3[i][0][1][:,:,time_start:time_end,:])
time_keys = torch.cat(time_keys_list, dim=2)
time_vals = torch.cat(time_vals_list, dim=2)
block_streaming_past_key_values_part2[i].append( (time_keys, time_vals) )
# Positions are shared by all layers, so they are appended only once (on layer 0).
if i == 0:
position_ids_part2 = torch.cat([position_ids_part2] + extract_timestamps_position_ids_list, dim=0)
# Stage 4: stitch every stored segment into one cache covering positions 0..visual_token_end_pos-1.
merged_pkv = []
for layer_pkvs in all_past_key_values:
if not layer_pkvs:
continue
keys = torch.cat([pkv[0].to(device=device) for pkv in layer_pkvs], dim=2) # dim=2 is the sequence dimension
values = torch.cat([pkv[1].to(device=device) for pkv in layer_pkvs], dim=2)
merged_pkv.append((keys, values))
# peak_memory_allocated = torch.cuda.max_memory_allocated()
# print(f"prefill peak GPU memory: {peak_memory_allocated / (1024**3):.2f} GB") # convert to GB
pkv = merged_pkv
# Drop the intermediate caches before decoding the suffix.
del block_streaming_past_key_values
del all_past_key_values
del block_streaming_past_key_values_part1
del block_streaming_past_key_values_part2
del block_streaming_past_key_values_part3
torch.cuda.empty_cache()
# TODO: bi-decoding acceleration
mixed_prefill_past_key_values = pkv
prefill_len = visual_token_end_pos
# torch.cuda.reset_peak_memory_stats()
# Process suffix: plain forward with dense positions continuing after the prefilled cache.
if suffix_embeds.size(1) > 0:
seq_len = suffix_embeds.size(1)
total_len = prefill_len + seq_len
position_ids = torch.arange(prefill_len, total_len, device=device, dtype=torch.long).expand(bsz, -1)
key_position_ids = torch.arange(0, total_len, device=device, dtype=torch.long).expand(bsz, -1)
attention_mask = torch.ones((bsz, total_len), device=device, dtype=torch.long)
outputs = super().forward(
inputs_embeds=suffix_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
key_position_ids=key_position_ids,
past_key_values=mixed_prefill_past_key_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=True,
return_dict=return_dict,
# blocks_positions=None,
)
# peak_memory_allocated = torch.cuda.max_memory_allocated()
# print(f"decoding peak GPU memory: {peak_memory_allocated / (1024**3):.2f} GB") # convert to GB
del mixed_prefill_past_key_values
torch.cuda.empty_cache()
return outputs
def forward_mask(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    dpo_forward: Optional[bool] = False,
    cache_position=None,
    visual_token_start_pos=None,
    visual_token_end_pos=None,
    time_token_start_indices=None,
    time_token_end_indices=None,
    frames_num=None,
    time_token_indices=None,
    prev_blocks_num=None,
    block_size_chosed=None
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Dense prefill of the whole prompt under a block-sparse attention mask.

    Builds the additive mask with ``get_sparse_attention_mask`` and runs the parent
    forward over the full prompt with it in place of ``attention_mask``.
    ``dpo_forward`` and ``cache_position`` are accepted but not used.

    Raises:
        ValueError: if ``block_size_chosed`` or ``frames_num`` is missing.
    """
    if block_size_chosed is None or frames_num is None:
        raise ValueError("forward_mask requires block_size_chosed and frames_num")
    total_len = inputs_embeds.size(1)
    block_size = block_size_chosed
    # ceil(frames_num / (4 * block_size)); the factor 4 presumably reflects one time label
    # per 4 frames in the prompt layout - TODO confirm.
    num_blocks = (frames_num + block_size * 4 - 1) // (block_size * 4)
    # The second value (kept fraction of causal entries) is diagnostic only.
    final_mask, _ratio = self.get_sparse_attention_mask(
        total_len,
        num_blocks,
        block_size,
        time_token_start_indices,
        time_token_end_indices,
        time_token_indices,
        visual_token_start_pos,
        visual_token_end_pos,
        attention_mask,
        inputs_embeds,
        prev_blocks_num,
    )
    return super().forward(
        input_ids=input_ids,
        attention_mask=final_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    key_position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[List[List[int]]] = None,
    return_dict: Optional[bool] = None,
    modalities: Optional[List[str]] = ["image"],
    dpo_forward: Optional[bool] = False,
    cache_position=None,
    time_embedding=None,
    visual_token_start_pos=None,
    visual_token_end_pos=None,
    time_token_start_indices=None,
    frames_num=None,
    time_token_indices=None,
    time_token_end_indices=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Route a forward pass to the decode fast path, a chunked prefill, or the dense model.

    * Single-token step with a cache: place the new token right after the cached keys
      (fixing ``position_ids`` / ``key_position_ids``) and call the parent forward.
    * Otherwise build multimodal embeddings when ``inputs_embeds`` is not given. If
      ``config.enable_chunk_prefill`` is set, run the prefill selected by
      ``config.prefill_config['chunk_prefill_mode']`` (``'streaming'`` or ``'mask'``).
      Both modes receive ``block_size_chosed=chunk_size`` and
      ``prev_blocks_num=chunk_size - step_size``.
    * Everything else (including an unknown prefill mode, which also emits a warning)
      runs the ordinary dense parent forward.
    """
    if input_ids is not None and input_ids.size(1) == 1 and past_key_values is not None:
        past_key_len = past_key_values[0][0].size(-2)
        device = position_ids.device if position_ids is not None else input_ids.device
        # Keys are numbered 0..past_key_len, so the new query must sit at past_key_len.
        if position_ids is None or position_ids[0][0] != past_key_len:
            pos_dtype = position_ids.dtype if position_ids is not None else torch.long
            position_ids = torch.tensor([[past_key_len]], device=device, dtype=pos_dtype)
        key_position_ids = torch.arange(0, past_key_len + 1, device=device, dtype=torch.long).expand(1, -1)
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            key_position_ids=key_position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    if inputs_embeds is None:
        (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities, image_sizes, time_embedding)
    if getattr(self.config, "enable_chunk_prefill", False):
        prefill_config = self.config.prefill_config
        prefill_mode = prefill_config['chunk_prefill_mode']
        chunk_size = prefill_config['chunk_size']
        step_size = prefill_config['step_size']
        offload = prefill_config['offload']
        # Number of earlier blocks every block may still attend to.
        prev_blocks_num = chunk_size - step_size
        if prefill_mode == 'streaming':
            return self.forward_streaming(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                key_position_ids=key_position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                visual_token_start_pos=visual_token_start_pos,
                visual_token_end_pos=visual_token_end_pos,
                time_token_start_indices=time_token_start_indices,
                frames_num=frames_num,
                time_token_indices=time_token_indices,
                time_token_end_indices=time_token_end_indices,
                block_size_chosed=chunk_size,
                prev_blocks_num=prev_blocks_num,
                offload=offload,
            )
        if prefill_mode == 'mask':
            # Previously passed the undefined names `block_size_chosed` / `prev_blocks_num`
            # (NameError); now uses the same values as the streaming branch.
            return self.forward_mask(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                visual_token_start_pos=visual_token_start_pos,
                visual_token_end_pos=visual_token_end_pos,
                time_token_start_indices=time_token_start_indices,
                frames_num=frames_num,
                time_token_indices=time_token_indices,
                time_token_end_indices=time_token_end_indices,
                block_size_chosed=chunk_size,
                prev_blocks_num=prev_blocks_num,
            )
        import warnings
        warnings.warn(f"Unknown chunk_prefill_mode {prefill_mode!r}; falling back to dense forward.")
    return super().forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    images: Optional[torch.Tensor] = None,
    image_sizes: Optional[torch.Tensor] = None,
    modalities: Optional[List[str]] = ["image"],
    time_embedding=None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Generate text for a (possibly video-conditioned) prompt.

    When ``images`` holds at least one frame, the token layout of
    ``time_embedding[0]`` is converted into absolute prompt positions. These are
    forwarded through ``kwargs`` so the prefill paths of ``forward`` can use them:
    ``visual_token_start_pos``/``end_pos``, ``time_token_start_indices``,
    ``time_token_indices``, ``time_token_end_indices`` and ``frames_num``.

    Raises:
        NotImplementedError: if ``inputs_embeds`` is passed.
        ValueError: if images are given without ``time_embedding``.
    """
    position_ids = kwargs.pop("position_ids", None)
    attention_mask = kwargs.pop("attention_mask", None)
    if "inputs_embeds" in kwargs:
        raise NotImplementedError("`inputs_embeds` is not supported")
    if images is not None and images[0].size(0) > 0:
        if time_embedding is None:
            raise ValueError("`time_embedding` is required when `images` are provided")
        # Placeholder id in `inputs` marking where the visual tokens are spliced in.
        image_token_index = -200
        time_ids = time_embedding[0]
        visual_token_start_pos = (inputs == image_token_index).nonzero(as_tuple=True)[1].item()

        def _absolute(match, offset):
            # Positions (within time_ids) where `match` is True, shifted by `offset`.
            return [pos + offset for pos in match.nonzero(as_tuple=True)[0].cpu().tolist()]

        kwargs['visual_token_start_pos'] = visual_token_start_pos
        kwargs['visual_token_end_pos'] = visual_token_start_pos + time_ids.size(0)
        # 1462 presumably the "Time" token opening each label - TODO confirm against the tokenizer.
        kwargs['time_token_start_indices'] = _absolute(time_ids == 1462, visual_token_start_pos)
        kwargs['frames_num'] = images[0].size(0)
        # Every id other than the 151654 visual placeholder is time-label text.
        kwargs['time_token_indices'] = _absolute(time_ids != 151654, visual_token_start_pos)
        # 25 presumably the ":" closing each label; the +1 makes each end exclusive.
        kwargs['time_token_end_indices'] = _absolute(time_ids == 25, visual_token_start_pos + 1)
    if images is not None:
        (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes, time_embedding=time_embedding)
    else:
        inputs_embeds = self.get_model().embed_tokens(inputs)
    return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
@torch.no_grad()
def chat(self,
         video_path,
         tokenizer,
         user_prompt,
         chat_history=None,
         return_history=True,
         max_num_frames=512,
         sample_fps=1,
         max_sample_fps=4,
         generation_config=None):
    """Answer ``user_prompt`` about the video at ``video_path``.

    On the first turn (no or empty ``chat_history``), the image placeholder token is
    prepended to the prompt. Later turns replay ``chat_history``, whose first message
    must contain that token. A provided ``chat_history`` list is extended in place with
    the new user/assistant messages.

    Args:
        generation_config: extra keyword arguments for ``generate``. The dict is copied,
            so neither the caller's dict nor a shared default is mutated
            (``stopping_criteria`` is always overridden).

    Returns:
        ``(answer, chat_history)`` when ``return_history`` is true, else ``answer``.

    Raises:
        ValueError: if ``chat_history`` is non-empty but its first message lacks the
            image token.
    """
    generation_config = dict(generation_config) if generation_config else {}
    # prepare text input
    conv = conv_templates["qwen_1_5"].copy()
    if not chat_history:
        user_prompt = f'{DEFAULT_IMAGE_TOKEN}\n{user_prompt}'
    else:
        if DEFAULT_IMAGE_TOKEN not in chat_history[0]['content']:
            raise ValueError(f"chat_history[0] must contain {DEFAULT_IMAGE_TOKEN}: {chat_history}")
        for msg in chat_history:
            conv.append_message(msg['role'], msg['content'])
    conv.append_message(conv.roles[0], user_prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    generation_config["stopping_criteria"] = [stopping_criteria]
    # prepare video input
    frames, timestamps = load_video(video_path, max_num_frames, fps=sample_fps, max_fps=max_sample_fps)
    print(f'video has loaded, extract {len(frames)} frames.')
    # For every 4th timestamp: the ids of a "Time {t}s:" label followed by 144 visual
    # placeholder ids (151654) - presumably 4 frames x 36 tokens; TODO confirm.
    # `ts` (not `time`) avoids shadowing the time module.
    time_token_ids = []
    for ts in timestamps[::4]:
        time_token_ids.extend(tokenizer(f"Time {ts}s:").input_ids)
        time_token_ids.extend([151654] * 144)
    time_embedding = torch.tensor(time_token_ids, dtype=torch.long).to(self.model.device)
    time_stamps = [time_embedding]
    video_tensor = self.get_vision_tower().image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(self.model.device, dtype=torch.float16)
    with torch.inference_mode():
        output_ids = self.generate(input_ids, images=[video_tensor], time_embedding=time_stamps, modalities=["video"], **generation_config)
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    if chat_history is None:
        chat_history = []
    chat_history.append({"role": conv.roles[0], "content": user_prompt})
    chat_history.append({"role": conv.roles[1], "content": outputs})
    if return_history:
        return outputs, chat_history
    return outputs
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    """Extend the parent's generation inputs with visual data and prefill metadata.

    ``images`` / ``image_sizes`` are removed from ``kwargs`` before delegating and
    re-attached only when present. The six prefill-metadata keys stay in ``kwargs``.
    They are also written explicitly onto the result (``None`` when absent), so
    ``forward`` always receives them.
    """
    images = kwargs.pop("images", None)
    image_sizes = kwargs.pop("image_sizes", None)
    metadata_keys = (
        "visual_token_start_pos",
        "visual_token_end_pos",
        "time_token_start_indices",
        "frames_num",
        "time_token_indices",
        "time_token_end_indices",
    )
    metadata = {key: kwargs.get(key, None) for key in metadata_keys}
    model_inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
    for key, value in metadata.items():
        model_inputs[key] = value
    if images is not None:
        model_inputs["images"] = images
    if image_sizes is not None:
        model_inputs["image_sizes"] = image_sizes
    return model_inputs
# Make the custom architecture discoverable through the transformers Auto* factories.
AutoConfig.register("llava_qwen", LlavaQwenConfig)
# The second argument must be the *model* class: registering LlavaQwenConfig here (as before)
# would make AutoModelForCausalLM build a config object instead of a model.
AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)