# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
# | |
# This source code is licensed under the BSD license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import re | |
from typing import List, Tuple | |
import torch | |
# Public API of this module: the only name exported by `from <module> import *`.
__all__: List[str] = ["torch_version"]
def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
    """Parse a PyTorch version string into its ``(major, minor, patch)`` integers.

    Args:
        version: Version string to parse. Defaults to ``torch.__version__`` as it
            was when this function was defined (i.e. at module import time).

    Returns:
        A 3-tuple of ints such as ``(1, 8, 0)``, or an empty tuple when
        ``version`` is not of the form ``<int>.<int>.<int>[suffix][+local]``.
        A local build suffix (e.g. ``+cu118``) is ignored. A pre-release / dev
        suffix (e.g. ``a0fb``, ``.dev20230101``) is logged as a warning and the
        numeric components are still returned.
    """
    # The dots are escaped so they only match a literal "."; an unescaped "."
    # matches any character and would accept e.g. "11000" as (1, 0, 0).
    numbering = re.search(r"^(\d+)\.(\d+)\.(\d+)([^+]*)(\+\S*)?$", version)
    if not numbering:
        return ()
    # Catch torch versions from internal pre-releases, like `1.8.0a0fb`.
    if numbering.group(4):
        # Two options here:
        # - either skip this version (minor number check is not relevant)
        # - or check that our codebase is not broken by this ongoing development.
        # Assuming that we're interested in the second use-case more than the first,
        # return the pre-release or dev numbering.
        # A module-named logger avoids the implicit root `basicConfig()` that the
        # module-level `logging.warning()` performs when no handlers are set.
        logging.getLogger(__name__).warning(
            "Pytorch pre-release version %s - assuming intent to test it", version
        )
    return tuple(int(numbering.group(n)) for n in range(1, 4))