|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for eliminating boilerplate code to handle abstract streams with
the CPU device.
"""
|
from contextlib import contextmanager |
|
from typing import Generator, List, Optional, Union, cast |
|
|
|
import torch |
|
|
|
# Nothing is exported for wildcard imports; the helpers here are internal.
__all__: List[str] = []
|
|
|
|
|
class CPUStreamType:
    """Type of the CPU placeholder stream.

    CPU execution is synchronous and has no stream object, so a single
    sentinel instance of this class stands in wherever a
    :class:`torch.cuda.Stream` would otherwise be required.
    """

    pass
|
|
|
|
|
|
|
# The singleton placeholder used wherever a stream is expected but the device
# is CPU. Identity comparison against this object is how CPU "streams" are
# detected (see ``is_cuda`` below).
CPUStream = CPUStreamType()

# A stream in this module is either a real CUDA stream or the CPU placeholder.
AbstractStream = Union[torch.cuda.Stream, CPUStreamType]
|
|
|
|
|
def new_stream(device: torch.device) -> AbstractStream:
    """Create a stream bound to *device*.

    CUDA devices get a fresh :class:`torch.cuda.Stream`; every other device
    shares the :data:`CPUStream` placeholder, since CPU execution is
    synchronous and needs no real stream object.
    """
    if device.type == "cuda":
        return torch.cuda.Stream(device)
    return CPUStream
|
|
|
|
|
def current_stream(device: torch.device) -> AbstractStream:
    """Return the stream currently selected on *device*.

    Mirrors :func:`torch.cuda.current_stream`, falling back to the
    :data:`CPUStream` placeholder for non-CUDA devices.
    """
    return torch.cuda.current_stream(device) if device.type == "cuda" else CPUStream
|
|
|
|
|
def default_stream(device: torch.device) -> AbstractStream:
    """Return the default stream of *device*.

    Mirrors :func:`torch.cuda.default_stream`, falling back to the
    :data:`CPUStream` placeholder for non-CUDA devices.
    """
    return torch.cuda.default_stream(device) if device.type == "cuda" else CPUStream
|
|
|
|
|
@contextmanager
def use_device(device: torch.device) -> Generator[None, None, None]:
    """Select *device* for the duration of the ``with`` block.

    Equivalent to :func:`torch.cuda.device` for CUDA devices; a no-op for
    anything else, since only CUDA has a device context to switch.
    """
    if device.type == "cuda":
        with torch.cuda.device(device):
            yield
    else:
        yield
|
|
|
|
|
@contextmanager
def use_stream(stream: Optional[AbstractStream]) -> Generator[None, None, None]:
    """Make *stream* current for the duration of the ``with`` block.

    Equivalent to :func:`torch.cuda.stream` for CUDA streams. ``None`` and
    the CPU placeholder make this a no-op, because the CPU has no stream to
    switch.
    """
    if not stream or not is_cuda(stream):
        yield
    else:
        with torch.cuda.stream(as_cuda(stream)):
            yield
|
|
|
|
|
def get_device(stream: AbstractStream) -> torch.device:
    """Return the device that *stream* belongs to.

    The CPU placeholder maps to ``torch.device("cpu")``; a CUDA stream
    reports its own device.
    """
    if not is_cuda(stream):
        return torch.device("cpu")
    return as_cuda(stream).device
|
|
|
|
|
def wait_stream(source: AbstractStream, target: AbstractStream) -> None:
    """Block *source* until all work already queued on *target* completes.

    Counterpart of :meth:`torch.cuda.Stream.wait_stream` that also accepts
    the CPU placeholder on either side.
    """
    if not is_cuda(target):
        # The CPU "stream" is synchronous: anything queued on it has already
        # finished by the time we get here, so there is nothing to wait for.
        return

    if is_cuda(source):
        # CUDA source waiting on a CUDA target: enqueue an event wait on the
        # source stream; the host thread is not blocked.
        as_cuda(source).wait_stream(as_cuda(target))
    else:
        # A CPU source can only wait by blocking the host thread until the
        # CUDA target drains.
        as_cuda(target).synchronize()
|
|
|
|
|
|
|
|
|
def record_stream(tensor: torch.Tensor, stream: AbstractStream) -> None:
    """:meth:`torch.Tensor.record_stream` for either CPU or CUDA stream.

    A no-op for the CPU placeholder, since only the CUDA caching allocator
    needs stream-use bookkeeping.
    """
    if is_cuda(stream):
        # record_stream() marks the tensor's memory as in use by `stream` so
        # the CUDA caching allocator will not reuse it before that stream has
        # passed this point.
        #
        # NOTE(review): the zero-sized alias below presumably works around
        # record_stream() not covering the whole storage for view tensors
        # with a nonzero storage offset -- confirm against the upstream
        # PyTorch issue this was written for before changing it.
        tensor = tensor.new_empty([0]).set_(tensor.storage())

        tensor.record_stream(as_cuda(stream))
|
|
|
|
|
def is_cuda(stream: Optional[AbstractStream]) -> bool:
    """Return ``True`` if *stream* is a real CUDA stream.

    Both :data:`CPUStream` and ``None`` are reported as non-CUDA. The
    previous implementation compared only against :data:`CPUStream`, so a
    ``None`` argument -- explicitly permitted by the signature -- was
    wrongly classified as a valid CUDA stream.
    """
    return stream is not None and stream is not CPUStream
|
|
|
|
|
def as_cuda(stream: AbstractStream) -> torch.cuda.Stream:
    """Narrow *stream* to :class:`torch.cuda.Stream` for the type checker.

    Purely a static-typing cast with no runtime check; callers are expected
    to have verified the stream with :func:`is_cuda` first.
    """
    cuda_stream = cast(torch.cuda.Stream, stream)
    return cuda_stream
|
|