|
from collections.abc import Iterable |
|
from typing import Any |
|
|
|
from torch._vendor.packaging.version import InvalidVersion, Version |
|
from torch.version import __version__ as internal_version |
|
|
|
|
|
# Public API of this module: only the ``TorchVersion`` class is listed here.
__all__ = ["TorchVersion"]
|
|
|
|
|
class TorchVersion(str): |
|
"""A string with magic powers to compare to both Version and iterables! |
|
Prior to 1.10.0 torch.__version__ was stored as a str and so many did |
|
comparisons against torch.__version__ as if it were a str. In order to not |
|
break them we have TorchVersion which masquerades as a str while also |
|
having the ability to compare against both packaging.version.Version as |
|
well as tuples of values, eg. (1, 2, 1) |
|
Examples: |
|
Comparing a TorchVersion object to a Version object |
|
TorchVersion('1.10.0a') > Version('1.10.0a') |
|
Comparing a TorchVersion object to a Tuple object |
|
TorchVersion('1.10.0a') > (1, 2) # 1.2 |
|
TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1 |
|
Comparing a TorchVersion object against a string |
|
TorchVersion('1.10.0a') > '1.2' |
|
TorchVersion('1.10.0a') > '1.2.1' |
|
""" |
|
|
|
__slots__ = () |
|
|
|
|
|
def _convert_to_version(self, inp: Any) -> Any: |
|
if isinstance(inp, Version): |
|
return inp |
|
elif isinstance(inp, str): |
|
return Version(inp) |
|
elif isinstance(inp, Iterable): |
|
|
|
|
|
|
|
|
|
|
|
|
|
return Version(".".join(str(item) for item in inp)) |
|
else: |
|
raise InvalidVersion(inp) |
|
|
|
def _cmp_wrapper(self, cmp: Any, method: str) -> bool: |
|
try: |
|
return getattr(Version(self), method)(self._convert_to_version(cmp)) |
|
except BaseException as e: |
|
if not isinstance(e, InvalidVersion): |
|
raise |
|
|
|
|
|
return getattr(super(), method)(cmp) |
|
|
|
|
|
# Install the rich-comparison operators on ``TorchVersion``; each one routes
# through ``TorchVersion._cmp_wrapper``.
#
# ``__ne__`` is installed together with ``__eq__`` so that ``!=`` stays the
# logical inverse of ``==``. Without it, ``!=`` resolves to the inherited
# ``str.__ne__``, and e.g. ``TorchVersion("1.10.0") != (1, 10, 0)`` evaluates
# to True even though ``TorchVersion("1.10.0") == (1, 10, 0)`` is also True.
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(
        TorchVersion,
        cmp_method,
        # ``method=cmp_method`` binds the current name at definition time; a
        # plain closure would only ever see the loop's final value.
        lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
    )

# ``torch.version.__version__`` wrapped in ``TorchVersion``.
__version__ = TorchVersion(internal_version)
|
|