# IceClear
# upload files
# 42f2c22
from abc import ABC, abstractmethod
from typing import Sequence, Union
import torch
from ..types import SamplingDirection
class Timesteps(ABC):
    """
    Common base class for timestep schedules.

    Holds the maximum timestep ``T`` and reports whether the schedule is
    discrete (``T`` is an int) or continuous (``T`` is a float).
    """

    def __init__(self, T: Union[int, float]):
        # T is the inclusive upper bound, so it has to be strictly positive.
        assert T > 0
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        The largest timestep, inclusive.
        An int for a discrete schedule, a float for a continuous one.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Tell whether the schedule is continuous, i.e. whether T is a float.
        """
        upper_bound = self.T
        return isinstance(upper_bound, float)
class SamplingTimesteps(Timesteps):
    """
    Sampling timesteps.
    It defines the discretization of sampling steps.
    """

    def __init__(
        self,
        T: Union[int, float],
        timesteps: torch.Tensor,
        direction: SamplingDirection,
    ):
        """
        Args:
            T: Maximum timestep, forwarded to the base class (must be > 0).
            timesteps: 1-D tensor holding the timestep of each sampling step.
            direction: Sampling direction, stored as-is on the instance.
        """
        assert timesteps.ndim == 1
        super().__init__(T)
        self.timesteps = timesteps
        self.direction = direction

    def __len__(self) -> int:
        """
        Number of sampling steps.
        """
        return len(self.timesteps)

    def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
        """
        The timestep at the sampling step.
        Returns a scalar tensor if idx is int,
        or tensor of the same size if idx is a tensor.
        """
        return self.timesteps[idx]

    def index(self, t: torch.Tensor) -> torch.Tensor:
        """
        Find index by t.
        Return an int32 index tensor of the same shape as t.
        Index is -1 if t not found in timesteps.
        """
        # Compare every element of t (flattened in row-major order) against
        # every sampling timestep; each match yields a (position, step) pair.
        pos, step = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True)
        # Build the result as a flat tensor and reshape at the end. Calling
        # full_like(t).view(-1) instead would fail for non-contiguous t,
        # because full_like preserves t's strides and view() cannot flatten
        # such a layout.
        flat_idx = torch.full(
            (t.numel(),), -1, dtype=torch.int, device=t.device
        )
        flat_idx[pos] = step.int()
        return flat_idx.reshape(t.shape)