from typing import List, Optional, Union

import torch
import torchvision.transforms as T
from PIL import Image

from transformers.image_processing_base import BatchFeature
from transformers.image_processing_utils_fast import BaseImageProcessorFast, divide_to_patches
from transformers.image_utils import (
    make_list_of_images,
    get_image_size,
    get_image_type,
    ImageInput,
    ImageType,
    ChannelDimension,
)
from transformers.utils import TensorType
|
|
class NemotronNanoVLV2ImageProcessor(BaseImageProcessorFast):
    """Fast image processor that tiles each image into an InternVL-style grid of
    `image_size` x `image_size` tiles (optionally appending a thumbnail tile) and
    normalizes the result with `norm_mean` / `norm_std`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size=512,
        max_num_tiles=12,
        use_thumbnail=True,
        norm_mean=None,
        norm_std=None,
        do_rescale=True,
        patch_size=16,
        downsample_ratio=0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.max_num_tiles = max_num_tiles
        self.use_thumbnail = use_thumbnail
        # norm_mean / norm_std have no defaults here and must be provided
        # (e.g. from the checkpoint's preprocessor config) before preprocessing.
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.do_rescale = do_rescale
        # Vision tokens produced per tile after patchification and downsampling,
        # e.g. (512 // 16) ** 2 * 0.5 ** 2 = 256 with the defaults above.
        self.num_image_token = int((image_size // patch_size) ** 2 * (downsample_ratio ** 2))
|
    def _process_image(
        self,
        image: ImageInput,
        **kwargs,
    ) -> torch.Tensor:
        image_type = get_image_type(image)
        if image_type == ImageType.PIL:
            # Convert PIL inputs to RGB and to a float tensor in [0, 1];
            # non-PIL inputs are returned unchanged.
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image = T.ToTensor()(image)
        return image
|
    def _preprocess(
        self,
        images: List[torch.Tensor],
        image_size: Optional[int] = None,
        max_num_tiles: Optional[int] = None,
        use_thumbnail: Optional[bool] = None,
        do_rescale: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        image_size = image_size if image_size is not None else self.image_size
        max_num_tiles = max_num_tiles if max_num_tiles is not None else self.max_num_tiles
        use_thumbnail = use_thumbnail if use_thumbnail is not None else self.use_thumbnail
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale

        images = make_list_of_images(images)

        # Tile every image and record how many tiles it produced so the tiles can
        # be regrouped per image downstream.
        all_patches = []
        num_patches = []
        for image in images:
            patches = dynamic_preprocess(image, image_size, max_num_tiles, use_thumbnail)
            all_patches.extend(patches)
            num_patches.append(len(patches))

        # Stack all tiles into a single batch and normalize channel-wise.
        pixel_values = torch.stack(all_patches, dim=0)
        norm_mean = torch.tensor(self.norm_mean).view(1, 3, 1, 1)
        norm_std = torch.tensor(self.norm_std).view(1, 3, 1, 1)
        pixel_values = (pixel_values - norm_mean) / norm_std
        return BatchFeature(
            data={"pixel_values": pixel_values, "num_patches": num_patches},
            tensor_type=return_tensors,
        )
|
|
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate every (cols, rows) grid whose tile count lies in [min_num, max_num]."""
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(target_ratios, key=lambda x: x[0] * x[1])
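
# Illustrative example: get_internvl_target_ratios(1, 4) yields the grids
# (1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1), sorted by tile
# count (grids with equal tile counts may come back in any relative order, since
# they are collected in a set before sorting).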
|
|
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Return the (cols, rows) grid whose aspect ratio is closest to the image's."""
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            # Tie-break: keep the larger grid only when the image covers more than
            # half of that grid's area at image_size resolution.
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio
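
# Illustrative example: a square 800 x 800 image with image_size = 512 matches the
# (1, 1), (2, 2) and (3, 3) grids equally well; 800 * 800 > 0.5 * 512 * 512 * 4 but
# not > 0.5 * 512 * 512 * 9, so the (2, 2) grid is selected.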
|
|
def calculate_targets(
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
) -> tuple[int, int, int]:
    """Pick the closest tile grid and return (blocks, target_width, target_height)."""
    aspect_ratio = orig_width / orig_height

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    return blocks, target_width, target_height
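
# Worked example: a 1024 x 768 image with image_size = 512 and the default
# max_num_tiles = 12 has aspect ratio 4/3, whose closest grid is (4, 3), so
# calculate_targets returns blocks = 12, target_width = 2048, target_height = 1536.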
|
|
def dynamic_preprocess(image, image_size=512, max_num_tiles=12, use_thumbnail=True):
    orig_height, orig_width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    target_ratios = get_internvl_target_ratios(1, max_num_tiles)

    blocks, target_width, target_height = calculate_targets(
        orig_width,
        orig_height,
        target_ratios,
        image_size,
    )

    # Resize to the chosen grid and split into image_size x image_size tiles.
    resized_img = T.Resize((target_height, target_width), interpolation=T.InterpolationMode.BICUBIC)(image)
    patches = divide_to_patches(resized_img, image_size)
    assert len(patches) == blocks, f"expected {blocks} tiles, got {len(patches)}"

    # Append a full-image thumbnail tile whenever the image was split into multiple tiles.
    if use_thumbnail and len(patches) != 1:
        thumbnail_img = T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(image)
        patches.append(thumbnail_img)

    return patches
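

if __name__ == "__main__":
    # Minimal usage sketch exercising the tiling helpers directly. The normalization
    # constants below are illustrative placeholders (standard ImageNet statistics),
    # not necessarily the values shipped with any particular checkpoint.
    processor = NemotronNanoVLV2ImageProcessor(
        norm_mean=[0.485, 0.456, 0.406],
        norm_std=[0.229, 0.224, 0.225],
    )
    image = Image.new("RGB", (1024, 768), color=(127, 127, 127))
    tensor = processor._process_image(image)
    batch = processor._preprocess([tensor], return_tensors="pt")
    # A 1024 x 768 input maps to a (4, 3) grid: 12 tiles plus 1 thumbnail,
    # i.e. pixel_values of shape (13, 3, 512, 512).
    print(batch["pixel_values"].shape, batch["num_patches"])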
|
|
|