import gc
import math
import os
import re
import warnings
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from ..utils import _log_api_usage_once
from . import _video_opt
from ._video_deprecation_warning import _raise_video_deprecation_warning

try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
    try:
        FFmpegError = av.FFmpegError
    except AttributeError:
        # Older PyAV versions expose av.AVError instead of av.FFmpegError.
        FFmpegError = av.AVError
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )


def _check_av_available() -> None:
    # If PyAV failed to import (or is too old), ``av`` holds the ImportError to raise.
    if isinstance(av, Exception):
        raise av


def _av_available() -> bool:
    return not isinstance(av, Exception)


# PyAV has some reference cycles, so we periodically trigger a garbage
# collection from _read_from_stream to keep memory bounded.
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10


def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    [DEPRECATED] Writes a 4d tensor in [T, H, W, C] format in a video file.

    .. warning::

        DEPRECATED: All the video decoding and encoding capabilities of torchvision
        are deprecated from version 0.22 and will be removed in version 0.24. We
        recommend that you migrate to
        `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
        consolidate the future decoding/encoding capabilities of PyTorch.

    This function relies on PyAV (and therefore, ultimately, FFmpeg) to encode
    videos; you can get more fine-grained control by referring to the other
    options at your disposal within `the FFmpeg wiki
    <http://trac.ffmpeg.org/wiki#Encoding>`_.

    Args:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): video frames per second
        video_codec (str): the name of the video codec, e.g. "libx264" or "h264"
        options (Dict): dictionary containing options to be passed into the PyAV video stream.
            The available options are codec-dependent and can be found on
            `the FFmpeg wiki <http://trac.ffmpeg.org/wiki#Encoding>`_.
        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
            and N is the number of samples
        audio_fps (Number): audio sample rate, typically 44100 or 48000
        audio_codec (str): the name of the audio codec, e.g. "mp3" or "aac"
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream.
            The available options are codec-dependent and can be found on
            `the FFmpeg wiki <http://trac.ffmpeg.org/wiki#Encoding>`_.

    Examples::
        >>> # Creating a libx264 video with CRF 17, for visually lossless footage:
        >>>
        >>> from torchvision.io import write_video
        >>> # 1000 frames of a 100x100, 3-channel image.
        >>> vid = torch.randint(0, 256, (1000, 100, 100, 3), dtype=torch.uint8)
        >>> write_video("video.mp4", vid, fps=24, options={"crf": "17"})
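        >>> # A sketch of additionally muxing a stereo audio track; the values
        >>> # below are illustrative, not prescriptive. Audio tensors must be
        >>> # in [channels, samples] format.
        >>> audio = torch.zeros((2, 48000), dtype=torch.float32)
        >>> write_video("video.mp4", vid, fps=24, audio_array=audio,
        ...             audio_fps=48000, audio_codec="aac")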
    """
    _raise_video_deprecation_warning()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)

    # PyAV does not support float frame rates, so round to the nearest integer.
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Map FFmpeg sample-format names to the numpy dtypes they correspond to.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio encoder.
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            try:
                frame.pict_type = "NONE"
            except TypeError:
                # Newer PyAV versions expect the PictureType enum instead of a string.
                from av.video.frame import PictureType

                frame.pict_type = PictureType.NONE
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush the video encoder.
        for packet in stream.encode():
            container.mux(packet)


def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # Convert the offsets from seconds to pts units using the stream time base.
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (two frames in a
        # single packet), so we may need to buffer some extra frames to sort
        # everything properly.
        extradata = stream.codec_context.extradata
        # Check whether the extradata advertises a "packed" DivX build, in which
        # case the extra buffering is needed.
        if extradata and b"DivX" in extradata:
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # Some files seem to seek to a wrong location when seeking to the exact pts,
    # so seek one pts earlier to be safe.
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # Jump back a few extra frames so that out-of-order pts are still covered.
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except FFmpegError:
        # Seeking failed, possibly a corrupted file; return no frames.
        return []
    buffer_count = 0
    try:
        for _idx, frame in enumerate(container.decode(**stream_name)):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except FFmpegError:
        # Decoding failed mid-stream; keep whatever frames were decoded so far.
        pass
    # Ensure that the results are sorted by pts.
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # If there is no frame that exactly matches the pts of start_offset, add
        # the last frame with a smaller pts to guarantee that all the necessary
        # data is included. This is most useful for audio.
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Negative index: trim the samples that extend past ref_end.
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]


def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """[DEPRECATED] Reads a video from a file, returning both the video frames and the audio frames

    .. warning::

        DEPRECATED: All the video decoding and encoding capabilities of torchvision
        are deprecated from version 0.22 and will be removed in version 0.24. We
        recommend that you migrate to
        `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
        consolidate the future decoding/encoding capabilities of PyTorch.

    Args:
        filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
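
    Example::
        >>> # A minimal usage sketch; "video.mp4" stands in for an existing file.
        >>> from torchvision.io import read_video
        >>> vframes, aframes, info = read_video("video.mp4", start_pts=0, end_pts=2, pts_unit="sec")
        >>> vframes.shape  # [T, H, W, C] with the default output_format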
    """
    _raise_video_deprecation_warning()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        if not os.path.exists(filename):
            raise RuntimeError(f"File not found: {filename}")
        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
    else:
        _check_av_available()

        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        video_frames = []
        audio_frames = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base
                if container.streams.video:
                    video_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    video_fps = container.streams.video[0].average_rate
                    # Guard against potentially corrupted files.
                    if video_fps is not None:
                        info["video_fps"] = float(video_fps)

                if container.streams.audio:
                    audio_frames = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate

        except FFmpegError:
            # Failed to open or decode the file; fall through and return empty tensors.
            pass

        vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
        aframes_list = [frame.to_ndarray() for frame in audio_frames]

        if vframes_list:
            vframes = torch.as_tensor(np.stack(vframes_list))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if aframes_list:
            aframes = np.concatenate(aframes_list, 1)
            aframes = torch.as_tensor(aframes)
            if pts_unit == "sec":
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
        else:
            aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T, H, W, C] --> [T, C, H, W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info


def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    # Streams encoded by libavcodec carry a "Lavc" marker in their extradata; the
    # heuristic here assumes packet timestamps are reliable for those streams.
    extradata = container.streams[0].codec_context.extradata
    if extradata is None:
        return False
    if b"Lavc" in extradata:
        return True
    return False


def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    if _can_read_timestamps_from_packets(container):
        # Fast path: read pts from the packets without decoding any frames.
        return [x.pts for x in container.demux(video=0) if x.pts is not None]
    else:
        return [x.pts for x in container.decode(video=0) if x.pts is not None]


def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """[DEPRECATED] List the video frames timestamps.

    .. warning::

        DEPRECATED: All the video decoding and encoding capabilities of torchvision
        are deprecated from version 0.22 and will be removed in version 0.24. We
        recommend that you migrate to
        `TorchCodec <https://github.com/pytorch/torchcodec>`__, where we'll
        consolidate the future decoding/encoding capabilities of PyTorch.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned,
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video
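
    Example::
        >>> # A minimal usage sketch; "video.mp4" stands in for an existing file.
        >>> pts, video_fps = read_video_timestamps("video.mp4", pts_unit="sec")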
    """
    _raise_video_deprecation_warning()
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except FFmpegError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                video_fps = float(video_stream.average_rate)
    except FFmpegError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    # Timestamps can come back out of order (e.g. with B-frames), so sort them.
    pts.sort()

    if pts_unit == "sec":
        pts = [x * video_time_base for x in pts]

    return pts, video_fps