|
|
|
|
|
from typing import Any, Callable, Optional |
|
|
|
|
|
r""" |
|
The following constraints are implemented: |
|
|
|
- ``constraints.boolean`` |
|
- ``constraints.cat`` |
|
- ``constraints.corr_cholesky`` |
|
- ``constraints.dependent`` |
|
- ``constraints.greater_than(lower_bound)`` |
|
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
|
- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` |
|
- ``constraints.integer_interval(lower_bound, upper_bound)`` |
|
- ``constraints.interval(lower_bound, upper_bound)`` |
|
- ``constraints.less_than(upper_bound)`` |
|
- ``constraints.lower_cholesky`` |
|
- ``constraints.lower_triangular`` |
|
- ``constraints.multinomial`` |
|
- ``constraints.nonnegative`` |
|
- ``constraints.nonnegative_integer`` |
|
- ``constraints.one_hot`` |
|
- ``constraints.positive_integer`` |
|
- ``constraints.positive`` |
|
- ``constraints.positive_semidefinite`` |
|
- ``constraints.positive_definite`` |
|
- ``constraints.real_vector`` |
|
- ``constraints.real`` |
|
- ``constraints.simplex``
- ``constraints.stack``
|
- ``constraints.square`` |
|
- ``constraints.symmetric`` |
|
- ``constraints.unit_interval`` |
|
""" |
|
|
|
import torch |
|
|
|
|
|
# Public API: the ``Constraint`` base class, the ``dependent`` helpers, and
# every exported constraint. Parameter-free constraints are exported as
# instances; parameterized ones are exported as their classes (see the
# assignments at the bottom of this module).
__all__ = [
    "Constraint",
    "boolean",
    "cat",
    "corr_cholesky",
    "dependent",
    "dependent_property",
    "greater_than",
    "greater_than_eq",
    "independent",
    "integer_interval",
    "interval",
    "half_open_interval",
    "is_dependent",
    "less_than",
    "lower_cholesky",
    "lower_triangular",
    "multinomial",
    "nonnegative",
    "nonnegative_integer",
    "one_hot",
    "positive",
    "positive_semidefinite",
    "positive_definite",
    "positive_integer",
    "real",
    "real_vector",
    "simplex",
    "square",
    "stack",
    "symmetric",
    "unit_interval",
]
|
|
|
|
|
class Constraint:
    """
    Abstract base class for constraints.

    A constraint object describes a region inside which a variable is valid,
    for instance the region over which the variable may be optimized.

    Attributes:
        is_discrete (bool): Whether the constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together form a
            single event. :meth:`check` reduces over this many dimensions
            when computing validity.
    """

    is_discrete = False
    event_dim = 0

    def check(self, value):
        """
        Return a boolean tensor of shape ``sample_shape + batch_shape`` telling,
        for each event in ``value``, whether it satisfies this constraint.

        Subclasses must override this method.
        """
        raise NotImplementedError

    def __repr__(self):
        # Concrete constraints are private ``_Name`` classes; dropping the
        # first character of the class name makes the repr read ``Name()``.
        return f"{self.__class__.__name__[1:]}()"
|
|
|
|
|
class _Dependent(Constraint):
    """
    Placeholder for variables whose support depends on other variables.
    Such variables obey no simple coordinate-wise constraint.

    Args:
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises a NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises a NotImplementedError.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self) -> bool:
        stored = self._is_discrete
        if stored is NotImplemented:
            raise NotImplementedError(".is_discrete cannot be determined statically")
        return stored

    @property
    def event_dim(self) -> int:
        stored = self._event_dim
        if stored is NotImplemented:
            raise NotImplementedError(".event_dim cannot be determined statically")
        return stored

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Support for syntax to customize static attributes::

            constraints.dependent(is_discrete=True, event_dim=1)

        Any attribute not passed explicitly is inherited from ``self``.
        """
        return _Dependent(
            is_discrete=(
                self._is_discrete if is_discrete is NotImplemented else is_discrete
            ),
            event_dim=self._event_dim if event_dim is NotImplemented else event_dim,
        )

    def check(self, x):
        # A dependent support has no static region to test against.
        raise ValueError("Cannot determine validity of dependent constraint")
|
|
|
|
|
def is_dependent(constraint):
    """
    Check whether ``constraint`` is a ``_Dependent`` object.

    Args:
        constraint: A ``Constraint`` object.

    Returns:
        ``bool``: True if ``constraint`` is an instance of ``_Dependent``
        (including subclasses such as ``_DependentProperty``), False otherwise.

    Examples:
        >>> import torch
        >>> from torch.distributions import Bernoulli
        >>> from torch.distributions.constraints import is_dependent

        >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True))
        >>> constraint1 = dist.arg_constraints["probs"]
        >>> constraint2 = dist.arg_constraints["logits"]

        >>> for constraint in [constraint1, constraint2]:
        ...     if is_dependent(constraint):
        ...         continue
    """
    dependent_type = _Dependent
    return isinstance(constraint, dependent_type)
|
|
|
|
|
class _DependentProperty(property, _Dependent):
    """
    Decorator extending ``@property`` so that it behaves like a ``Dependent``
    constraint when accessed on the class, and like an ordinary property when
    accessed on an instance.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function being decorated.
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises a NotImplementedError.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises a NotImplementedError.
    """

    def __init__(
        self,
        fn: Optional[Callable[..., Any]] = None,
        *,
        is_discrete: Optional[bool] = NotImplemented,
        event_dim: Optional[int] = NotImplemented,
    ) -> None:
        # ``super()`` resolves to ``property.__init__``, which takes ``fn`` as
        # the getter; the static attributes are stored directly here instead
        # of going through ``_Dependent.__init__``.
        super().__init__(fn)
        self._is_discrete, self._event_dim = is_discrete, event_dim

    def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty":
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self): ...

        Returns a new ``_DependentProperty`` wrapping ``fn`` that carries this
        object's static attributes.
        """
        return _DependentProperty(
            fn,
            is_discrete=self._is_discrete,
            event_dim=self._event_dim,
        )
|
|
|
|
|
class _IndependentConstraint(Constraint):
    """
    Wrap a constraint so that :meth:`check` additionally aggregates over
    ``reinterpreted_batch_ndims`` rightmost dimensions: an event is valid only
    when every one of its independent entries is valid.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self) -> bool:
        # Discreteness is inherited unchanged from the wrapped constraint.
        return self.base_constraint.is_discrete

    @property
    def event_dim(self) -> int:
        # The reinterpreted batch dimensions become part of the event.
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        ndims = self.reinterpreted_batch_ndims
        per_entry = self.base_constraint.check(value)
        if per_entry.dim() < ndims:
            expected = self.base_constraint.event_dim + ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        # Collapse the trailing ``ndims`` dimensions into one and require
        # every entry along it to be valid.
        leading_shape = per_entry.shape[: per_entry.dim() - ndims]
        return per_entry.reshape(leading_shape + (-1,)).all(-1)

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return f"{name}({self.base_constraint!r}, {self.reinterpreted_batch_ndims})"
|
|
|
|
|
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`, elementwise.
    """

    is_discrete = True

    def check(self, value):
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
|
|
|
|
|
class _OneHot(Constraint):
    """
    Constrain to one-hot vectors along the rightmost dimension.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        # Every entry must be 0 or 1, and each vector must sum to exactly 1.
        entries_binary = ((value == 0) | (value == 1)).all(-1)
        sums_to_one = value.sum(-1).eq(1)
        return entries_binary & sums_to_one
|
|
|
|
|
class _IntegerInterval(Constraint):
    """
    Constrain to the integers in the closed interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        in_range = (self.lower_bound <= value) & (value <= self.upper_bound)
        return is_integral & in_range

    def __repr__(self):
        return (
            f"{self.__class__.__name__[1:]}"
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
|
|
|
|
|
class _IntegerLessThan(Constraint):
    """
    Constrain to the integers in `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        return is_integral & (value <= self.upper_bound)

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}(upper_bound={self.upper_bound})"
|
|
|
|
|
class _IntegerGreaterThan(Constraint):
    """
    Constrain to the integers in `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        return is_integral & (value >= self.lower_bound)

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}(lower_bound={self.lower_bound})"
|
|
|
|
|
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    Only NaN is rejected.
    """

    def check(self, value):
        # NaN is the only value that compares unequal to itself.
        return value == value
|
|
|
|
|
class _GreaterThan(Constraint):
    """
    Constrain to the open real half line `(lower_bound, inf]`
    (strict inequality at the lower bound).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound < value

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}(lower_bound={self.lower_bound})"
|
|
|
|
|
class _GreaterThanEq(Constraint):
    """
    Constrain to the closed real half line `[lower_bound, inf]`
    (the lower bound itself is allowed).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        return self.lower_bound <= value

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}(lower_bound={self.lower_bound})"
|
|
|
|
|
class _LessThan(Constraint):
    """
    Constrain to the real half line `[-inf, upper_bound)`
    (strict inequality at the upper bound).
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        return value < self.upper_bound

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}(upper_bound={self.upper_bound})"
|
|
|
|
|
class _Interval(Constraint):
    """
    Constrain to the closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        return (
            f"{self.__class__.__name__[1:]}"
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
|
|
|
|
|
class _HalfOpenInterval(Constraint):
    """
    Constrain to the half-open real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        above_lower = self.lower_bound <= value
        strictly_below_upper = value < self.upper_bound
        return above_lower & strictly_below_upper

    def __repr__(self):
        return (
            f"{self.__class__.__name__[1:]}"
            f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})"
        )
|
|
|
|
|
class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension:
    every entry is `>= 0` and the entries sum to 1 within an absolute
    tolerance of `1e-6`.
    """

    event_dim = 1

    def check(self, value):
        nonnegative = torch.all(value >= 0, dim=-1)
        sums_to_one = (value.sum(-1) - 1).abs() < 1e-6
        return nonnegative & sums_to_one
|
|
|
|
|
class _Multinomial(Constraint):
    """
    Constrain to nonnegative values along the rightmost dimension whose sum is
    at most ``upper_bound``.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``. Integrality
    of the entries is not checked.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        # Consistent with every other parameterized constraint in this module.
        super().__init__()

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)

    def __repr__(self):
        # Include the bound, like the other parameterized constraints do;
        # the base-class repr would drop it.
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string
|
|
|
|
|
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices.
    """

    event_dim = 2

    def check(self, value):
        # A matrix is lower-triangular iff it equals its own lower triangle.
        # Reduce with ``all`` over the two matrix dims instead of
        # ``view(..., -1).min(-1)[0]``: the result is the same for non-empty
        # matrices, but ``all`` also works for empty (0x0) matrices, where
        # ``min`` over an empty dimension raises and ``view`` with ``-1`` on a
        # zero-element tensor is ambiguous.
        return (value.tril() == value).all(-1).all(-1)
|
|
|
|
|
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        # ``all`` reductions (rather than ``view(..., -1).min(-1)[0]``) give the
        # same result on non-empty matrices and also handle empty (0x0)
        # matrices, for which ``min`` over an empty dimension raises.
        lower_triangular = (value.tril() == value).all(-1).all(-1)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return lower_triangular & positive_diagonal
|
|
|
|
|
class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals
    whose rows each have unit Euclidean norm, up to a tolerance that scales
    with the dtype's machine epsilon and the matrix size.
    """

    event_dim = 2

    def check(self, value):
        eps = torch.finfo(value.dtype).eps
        # 10 is an adjustable fudge factor on top of eps * n.
        tolerance = eps * value.size(-1) * 10
        norm_error = (torch.linalg.norm(value.detach(), dim=-1) - 1.0).abs()
        rows_are_unit = norm_error.le(tolerance).all(dim=-1)
        return _LowerCholesky().check(value) & rows_are_unit
|
|
|
|
|
class _Square(Constraint):
    """
    Constrain to square matrices. Squareness depends only on the shape, so the
    result is a boolean tensor of the batch shape filled with a single value.
    """

    event_dim = 2

    def check(self, value):
        rows, cols = value.shape[-2], value.shape[-1]
        return torch.full(
            size=value.shape[:-2],
            fill_value=(rows == cols),
            dtype=torch.bool,
            device=value.device,
        )
|
|
|
|
|
class _Symmetric(_Square):
    """
    Constrain to symmetric square matrices, compared elementwise with their
    transpose via ``torch.isclose`` (``atol=1e-6``).
    """

    def check(self, value):
        square_check = super().check(value)
        if square_check.all():
            matches_transpose = torch.isclose(value, value.mT, atol=1e-6)
            return matches_transpose.all(-2).all(-1)
        # Non-square input: the (all-False) squareness result is final.
        return square_check
|
|
|
|
|
class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices: symmetric matrices whose
    eigenvalues are all nonnegative.
    """

    def check(self, value):
        sym_check = super().check(value)
        # Nothing symmetric (including every non-square input): no
        # eigendecomposition is needed.
        if not sym_check.any():
            return sym_check
        if not sym_check.all():
            # Mixed batch: substitute the identity for the non-symmetric
            # matrices so ``eigvalsh`` is not fed arbitrary data; their
            # result is masked out by ``sym_check`` below anyway.
            eye = torch.eye(value.size(-1), dtype=value.dtype, device=value.device)
            value = torch.where(sym_check[..., None, None], value, eye)
        # AND with ``sym_check`` so that, in a mixed batch, a symmetric but
        # indefinite matrix is rejected instead of inheriting the plain
        # symmetry result.
        return sym_check & torch.linalg.eigvalsh(value).ge(0).all(-1)
|
|
|
|
|
class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices: symmetric matrices that admit a
    Cholesky factorization.
    """

    def check(self, value):
        sym_check = super().check(value)
        # Nothing symmetric (including every non-square input): skip the
        # factorization.
        if not sym_check.any():
            return sym_check
        # ``cholesky_ex`` reports failure through ``info`` rather than raising.
        # AND with ``sym_check`` so that, in a mixed batch, a symmetric but
        # non-positive-definite matrix is rejected instead of inheriting the
        # plain symmetry result.
        return sym_check & torch.linalg.cholesky_ex(value).info.eq(0)
|
|
|
|
|
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints ``cseq`` to
    consecutive slices of a value along dimension ``dim``, the i-th slice
    having size ``lengths[i]``, in a way compatible with :func:`torch.cat`.

    Args:
        cseq: Iterable of ``Constraint`` objects.
        dim (int): Dimension along which the slices are taken.
        lengths: Optional iterable of slice sizes, one per constraint.
            Defaults to 1 for every constraint.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        # Materialize before validating: a one-shot iterator (e.g. a
        # generator) would otherwise be exhausted by the ``all(...)`` check
        # and then stored as an empty list.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self) -> bool:
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self) -> int:
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length
        return torch.cat(checks, self.dim)
|
|
|
|
|
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints ``cseq`` to the
    slices of a value along dimension ``dim`` (one constraint per slice), in a
    way compatible with :func:`torch.stack`.

    Args:
        cseq: Iterable of ``Constraint`` objects.
        dim (int): Dimension along which the value is sliced.
    """

    def __init__(self, cseq, dim=0):
        # Materialize before validating: a one-shot iterator (e.g. a
        # generator) would otherwise be exhausted by the ``all(...)`` check
        # and then stored as an empty list.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self) -> bool:
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self) -> int:
        dim = max(c.event_dim for c in self.cseq)
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )
|
|
|
|
|
|
|
# Public constraints. Parameter-free constraints are exposed as ready-made
# instances (e.g. ``real``); parameterized ones are exposed as their classes,
# so callers construct them (e.g. ``interval(low, high)``).
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
boolean = _Boolean()
one_hot = _OneHot()
# Integers >= 0.
nonnegative_integer = _IntegerGreaterThan(0)
# Integers >= 1.
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
real = _Real()
# ``real`` with the rightmost dimension treated as part of the event.
real_vector = independent(real, 1)
# Strictly greater than 0.
positive = _GreaterThan(0.0)
# Greater than or equal to 0.
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
multinomial = _Multinomial
# Closed interval [0, 1].
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
cat = _Cat
stack = _Stack
|
|