File size: 3,525 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Common utilities for tests.
They let the test suite run against a selected subset of frameworks instead of all of them.
"""

import logging
import os
from functools import lru_cache
from typing import List, Tuple

from einops import _backends
import warnings

__author__ = "Alex Rogozhnikov"


# Minimize noise in test logs: disabled loggers drop the records logged through them.
logging.getLogger("tensorflow").disabled = True
logging.getLogger("matplotlib").disabled = True

# Reduction names used for floating-point tests; deliberately excludes "any" and "all".
FLOAT_REDUCTIONS: Tuple[str, ...] = ("min", "max", "sum", "mean", "prod")  # does not include any/all


def find_names_of_all_frameworks() -> List[str]:
    """
    Collect ``framework_name`` of every class in the subclass tree of ``AbstractBackend``.

    Direct and indirect subclasses are both visited, using an explicit stack of
    classes still to expand.

    :return: framework names, one per discovered subclass
    """
    discovered = []
    to_visit = list(_backends.AbstractBackend.__subclasses__())
    while len(to_visit) > 0:
        current_cls = to_visit.pop()
        # queue grandchildren so that indirect subclasses are found too
        to_visit.extend(current_cls.__subclasses__())
        discovered.append(current_cls)
    return [cls.framework_name for cls in discovered]


# Environment variable holding the comma-separated list of framework names to test
# (read by parse_backends_to_test, produced by unparse_backends).
ENVVAR_NAME: str = "EINOPS_TEST_BACKENDS"


def unparse_backends(backend_names: List[str]) -> Tuple[str, str]:
    """
    Build the environment-variable setting that selects ``backend_names`` for testing.

    This is the counterpart of ``parse_backends_to_test``: the value is the names joined with commas.

    :param backend_names: framework names to select
    :return: pair ``(env var name, env var value)``
    :raises RuntimeError: if any name is not a known framework
    """
    known_frameworks = find_names_of_all_frameworks()
    for name in backend_names:
        if name in known_frameworks:
            continue
        raise RuntimeError(f"Unknown framework: {name}")
    joined_value = ",".join(backend_names)
    return ENVVAR_NAME, joined_value


@lru_cache(maxsize=1)
def parse_backends_to_test() -> List[str]:
    """
    Read the list of frameworks to test from the ``ENVVAR_NAME`` environment variable.

    The value is a comma-separated list of framework names, the format produced by
    ``unparse_backends``. Whitespace around names is ignored, as are empty entries
    (e.g. from a trailing comma).

    The result is cached after the first call, so later changes to the environment are
    not seen. Callers must not mutate the returned list, because it is the cached object.

    :return: framework names to test, in the order given
    :raises RuntimeError: if the variable is unset, lists no frameworks, or names an unknown framework
    """
    if ENVVAR_NAME not in os.environ:
        raise RuntimeError(f"Testing frameworks were not specified, env var {ENVVAR_NAME} not set")
    # tolerate "torch, numpy" and trailing commas instead of failing with a confusing
    # "Unknown framework: " error for an empty or space-padded entry
    parsed_backends = [name.strip() for name in os.environ[ENVVAR_NAME].split(",") if name.strip()]
    if not parsed_backends:
        raise RuntimeError(f"Testing frameworks were not specified, env var {ENVVAR_NAME} is empty")
    _known_backends = find_names_of_all_frameworks()
    for backend_name in parsed_backends:
        if backend_name not in _known_backends:
            raise RuntimeError(f"Unknown framework: {backend_name}")

    return parsed_backends


def is_backend_tested(backend: str) -> bool:
    """
    Tell whether ``backend`` is among the frameworks selected for testing.

    Used to skip a test if the corresponding backend is not tested.

    :param backend: framework name
    :return: True if ``backend`` is listed by ``parse_backends_to_test()``
    :raises RuntimeError: if ``backend`` is not a known framework name (probably a typo
        in the calling test), or if the selection of frameworks to test is missing or invalid
    """
    if backend not in find_names_of_all_frameworks():
        # same message format as the other "Unknown framework: ..." errors in this module
        raise RuntimeError(f"Unknown framework: {backend}")
    return backend in parse_backends_to_test()


def collect_test_backends(symbolic=False, layers=False) -> List[_backends.AbstractBackend]:
    """
    Instantiate the backends of one category that were selected for testing.

    The category is chosen by the two flags:
      - imperative operations (default): numpy, jax, torch, tensorflow, oneflow, paddle, cupy
      - imperative layers: torch, oneflow, paddle
      - symbolic operations: pytensor
      - symbolic layers: tf.keras

    Only backends whose ``framework_name`` is in ``parse_backends_to_test()`` are kept.
    A backend whose constructor raises ImportError is skipped with a warning.

    :param symbolic: symbolic or imperative frameworks?
    :param layers: layers or operations?
    :return: list of backend instances satisfying the conditions
    """
    if symbolic and layers:
        candidate_classes = [_backends.TFKerasBackend]
    elif symbolic:
        candidate_classes = [_backends.PyTensorBackend]
    elif layers:
        candidate_classes = [
            _backends.TorchBackend,
            _backends.OneFlowBackend,
            _backends.PaddleBackend,
        ]
    else:
        candidate_classes = [
            _backends.NumpyBackend,
            _backends.JaxBackend,
            _backends.TorchBackend,
            _backends.TensorflowBackend,
            _backends.OneFlowBackend,
            _backends.PaddleBackend,
            _backends.CupyBackend,
        ]

    selected_names = parse_backends_to_test()
    instances = []
    for backend_cls in candidate_classes:
        if backend_cls.framework_name in selected_names:
            try:
                instance = backend_cls()
            except ImportError:
                # a broken installation fails the specific test that needs it,
                # but the backend is simply skipped in all other test cases
                warnings.warn("backend could not be initialized for tests: {}".format(backend_cls))
            else:
                instances.append(instance)
    return instances