File size: 14,846 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 |
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import is_torch_available, is_vision_available
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import SizeDict
from transformers.processing_utils import VideosKwargs
from transformers.testing_utils import (
require_av,
require_cv2,
require_decord,
require_torch,
require_torchvision,
require_vision,
)
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
if is_torch_available():
import torch
if is_vision_available():
import PIL
from transformers import BaseVideoProcessor
from transformers.video_utils import VideoMetadata, load_video
def get_random_video(height, width, num_frames=8, return_torch=False):
    """
    Build a dummy video by repeating one randomly generated RGB frame.

    Args:
        height: Frame height in pixels.
        width: Frame width in pixels.
        num_frames: How many times the single random frame is repeated.
        return_torch: When truthy, return a channels-first `torch.Tensor` instead of a NumPy array.

    Returns:
        A `uint8` NumPy array of shape `(num_frames, height, width, 3)`, or, when `return_torch` is set,
        the same data as a tensor permuted to `(num_frames, 3, height, width)`.
    """
    frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
    # Every frame is the very same image: only the shapes matter to the tests.
    frames = np.array([frame for _ in range(num_frames)])
    if not return_torch:
        return frames
    # Channels-last -> channels-first.
    return torch.from_numpy(frames).permute(0, 3, 1, 2)
@require_vision
@require_torchvision
class BaseVideoProcessorTester(unittest.TestCase):
    """
    Tests that the `transforms` can be applied to a 4-dim array directly, i.e. to a whole video.

    Also covers the batching helpers (`make_batched_videos`) and the shape-grouping helpers
    (`group_videos_by_shape` / `reorder_videos`).
    """

    def test_make_batched_videos_pil(self):
        """`make_batched_videos` turns PIL inputs into a list of NumPy videos."""
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)
        pil_image = PIL.Image.fromarray(video[0])
        videos_list = make_batched_videos(pil_image)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], np.array(pil_image)))

        # Test a list of frames is converted to a list of 1 video
        video = get_random_video(16, 32)
        pil_video = [PIL.Image.fromarray(frame) for frame in video]
        videos_list = make_batched_videos(pil_video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a nested list of frames (a batch of videos) is converted to a list of array videos
        video = get_random_video(16, 32)
        pil_video = [PIL.Image.fromarray(frame) for frame in video]
        videos = [pil_video, pil_video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

    def test_make_batched_videos_numpy(self):
        """`make_batched_videos` turns NumPy inputs into a list of NumPy videos."""
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)[0]
        videos_list = make_batched_videos(video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], video))

        # Test a 4d array (a single video) is converted to a list of 1 video
        video = get_random_video(16, 32)
        videos_list = make_batched_videos(video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a list of videos is converted to a list of videos
        video = get_random_video(16, 32)
        videos = [video, video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

    @require_torch
    def test_make_batched_videos_torch(self):
        """`make_batched_videos` accepts torch tensors (single frame, single video, list of videos)."""
        # Test a single image is converted to a list of 1 video with 1 frame
        video = get_random_video(16, 32)[0]
        torch_video = torch.from_numpy(video)
        videos_list = make_batched_videos(torch_video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], np.ndarray)
        self.assertEqual(videos_list[0].shape, (1, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0][0], video))

        # Test a 4d tensor (a single video) is converted to a list of 1 video
        video = get_random_video(16, 32)
        torch_video = torch.from_numpy(video)
        videos_list = make_batched_videos(torch_video)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], torch.Tensor)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

        # Test a list of videos is converted to a list of videos
        video = get_random_video(16, 32)
        torch_video = torch.from_numpy(video)
        videos = [torch_video, torch_video]
        videos_list = make_batched_videos(videos)
        self.assertIsInstance(videos_list, list)
        self.assertIsInstance(videos_list[0], torch.Tensor)
        self.assertEqual(videos_list[0].shape, (8, 16, 32, 3))
        self.assertTrue(np.array_equal(videos_list[0], video))

    def test_resize(self):
        """Resizing a channels-first video changes only the two spatial dimensions."""
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(16, 32, return_torch=True)

        # Build the target size from an (8, 8) tuple.
        size_dict = SizeDict(**get_size_dict((8, 8), param_name="size"))
        resized_video = video_processor.resize(video, size=size_dict)
        self.assertIsInstance(resized_video, torch.Tensor)
        self.assertEqual(resized_video.shape, (8, 3, 8, 8))

    def test_normalize(self):
        """Normalization matches a manual per-channel `(x - mean) / std`."""
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        array = torch.randn(4, 3, 16, 32)
        mean = [0.1, 0.5, 0.9]
        std = [0.2, 0.4, 0.6]

        # mean and std are passed as plain lists; the reference broadcasts them over the channel dim.
        expected = (array - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
        normalized_array = video_processor.normalize(array, mean, std)
        torch.testing.assert_close(normalized_array, expected)

    def test_center_crop(self):
        """Center crop yields exactly the requested spatial size, even when larger than the input."""
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(16, 32, return_torch=True)

        # Crop sizes relative to the 16x32 video: smaller on both dimensions, larger on one
        # dimension only (width, then height) and larger on both dimensions.
        crop_sizes = [8, (8, 64), 20, (32, 64)]
        for size in crop_sizes:
            size_dict = SizeDict(**get_size_dict(size, default_to_square=True, param_name="crop_size"))
            cropped_video = video_processor.center_crop(video, size_dict)
            self.assertIsInstance(cropped_video, torch.Tensor)

            # An int crop size means a square crop.
            expected_size = (size, size) if isinstance(size, int) else size
            self.assertEqual(cropped_video.shape, (8, 3, *expected_size))

    def test_convert_to_rgb(self):
        """Videos with 1 or 4 channels come out with 3 channels."""
        video_processor = BaseVideoProcessor(model_init_kwargs=VideosKwargs)
        video = get_random_video(20, 20, return_torch=True)

        # Single-channel input: keep only the first channel.
        rgb_video = video_processor.convert_to_rgb(video[:, :1])
        self.assertEqual(rgb_video.shape, (8, 3, 20, 20))

        # Four-channel input: append one extra channel to the RGB video.
        rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
        self.assertEqual(rgb_video.shape, (8, 3, 20, 20))

    def test_group_and_reorder_videos(self):
        """Tests that videos can be grouped by frame size and number of frames"""
        video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
        video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
        video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)

        # (description, input videos, expected number of shape groups)
        cases = [
            ("same size but different number of frames", [video_1, video_2], 2),
            ("different size but same number of frames", [video_1, video_3], 2),
            # Some share a size or a frame count, but no two share both, so we get 3 groups
            ("no two videos with identical shapes", [video_1, video_2, video_3], 3),
            ("some videos with identical shapes", [video_1, video_1, video_3], 2),
            ("all videos with identical shapes", [video_1, video_1, video_1], 1),
        ]
        for description, videos, expected_num_groups in cases:
            with self.subTest(description):
                grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
                self.assertEqual(len(grouped_videos), expected_num_groups)

                regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
                # Reordering must give back one entry per *input video* (not per group).
                # NOTE: this used to be `assertTrue(len(...), n)`, which treats `n` as the
                # failure message and therefore never checked the length.
                self.assertEqual(len(regrouped_videos), len(videos))
                # ... and each entry must sit at its original position.
                for original, regrouped in zip(videos, regrouped_videos):
                    self.assertEqual(original.shape, regrouped.shape)
@require_vision
@require_av
class LoadVideoTester(unittest.TestCase):
    """
    Exercises `load_video` on one sample clip, fetched either straight from its URL or from a local
    copy downloaded from the Hub, with different backends and frame-sampling options.
    """

    _SAMPLE_URL = "https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4"
    # 243 frames is the whole video, no sampling applied
    _FULL_VIDEO_SHAPE = (243, 360, 640, 3)

    def _download_sample(self):
        # Fetch the sample clip from the Hub dataset repo and return its local path.
        return hf_hub_download(
            repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset"
        )

    def test_load_video_url(self):
        """Loading from a URL with default settings returns every frame."""
        frames, _ = load_video(self._SAMPLE_URL)
        self.assertEqual(frames.shape, self._FULL_VIDEO_SHAPE)

    def test_load_video_local(self):
        """Loading from a local path with default settings returns every frame."""
        local_path = self._download_sample()
        frames, _ = load_video(local_path)
        self.assertEqual(frames.shape, self._FULL_VIDEO_SHAPE)

    # FIXME: @raushan, yt-dlp downloading works but for some reason it cannot redirect to out buffer?
    # @requires_yt_dlp
    # def test_load_video_youtube(self):
    #     video = load_video("https://www.youtube.com/watch?v=QC8iQqtG0hg")
    #     self.assertEqual(video.shape, (243, 360, 640, 3))  # 243 frames is the whole video, no sampling applied

    @require_decord
    @require_torchvision
    @require_cv2
    def test_load_video_backend_url(self):
        """`decord` can read from a URL; `opencv` and `torchvision` must raise `ValueError`."""
        frames, _ = load_video(self._SAMPLE_URL, backend="decord")
        self.assertEqual(frames.shape, self._FULL_VIDEO_SHAPE)

        # Can't use certain backends with url
        for unsupported_backend in ("opencv", "torchvision"):
            with self.assertRaises(ValueError):
                load_video(self._SAMPLE_URL, backend=unsupported_backend)

    @require_decord
    @require_torchvision
    @require_cv2
    def test_load_video_backend_local(self):
        """Every backend reads the full local clip and returns a `VideoMetadata`."""
        local_path = self._download_sample()

        for backend in ("decord", "opencv", "torchvision"):
            frames, metadata = load_video(local_path, backend=backend)
            self.assertEqual(frames.shape, self._FULL_VIDEO_SHAPE)
            self.assertIsInstance(metadata, VideoMetadata)

    def test_load_video_num_frames(self):
        """`num_frames` fixes the number of frames returned."""
        for requested_frames in (16, 22):
            frames, _ = load_video(self._SAMPLE_URL, num_frames=requested_frames)
            self.assertEqual(frames.shape, (requested_frames, 360, 640, 3))

    def test_load_video_fps(self):
        """`fps` controls how many frames are sampled and cannot be combined with `num_frames`."""
        # (requested fps, number of frames expected back)
        for requested_fps, expected_frames in ((1, 9), (2, 19)):
            frames, _ = load_video(self._SAMPLE_URL, fps=requested_fps)
            self.assertEqual(frames.shape, (expected_frames, 360, 640, 3))

        # `num_frames` is mutually exclusive with `fps`
        with self.assertRaises(ValueError):
            load_video(self._SAMPLE_URL, fps=1, num_frames=10)
|