File size: 9,883 Bytes
e0be88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import unittest
from pathlib import Path
from typing import Callable

import pytest

from transformers.utils.import_utils import (
    Backend,
    VersionComparison,
    define_import_structure,
    spread_import_structure,
)


# Directory named "import_structures" located next to this test file; `test_definition` passes it to
# `define_import_structure`.
import_structures = Path(__file__).parent.joinpath("import_structures")


def fetch__all__(file_content):
    """
    Returns the content of the __all__ variable in the file content.
    Returns None if not defined, otherwise returns a list of strings.

    Only a plain `__all__ = [...]` assignment starting at column 0 is recognized, written either on a single
    line or with its entries on the lines between the opening line and a line containing only `]`.
    Empty entries (an empty list, a trailing comma, a blank line inside the list) are skipped, so
    `__all__ = []` yields `[]`.
    """
    lines = file_content.split("\n")
    for line_index, line in enumerate(lines):
        if not line.startswith("__all__ = "):
            continue

        # __all__ is defined on a single line
        if line.endswith("]"):
            # Split on the first "=" only, so the right-hand side is taken whole.
            raw_entries = line.split("=", 1)[1].strip(" []").split(",")
            entries = (raw_entry.strip("\"' ") for raw_entry in raw_entries)
            return [entry for entry in entries if entry]

        # __all__ is defined on multiple lines
        _all = []
        for entry_line in lines[line_index + 1 :]:
            if entry_line.strip() == "]":
                return _all
            entry = entry_line.strip("\"', ")
            if entry:
                _all.append(entry)
        # No closing bracket was found: keep scanning, which ends up returning None below.

    return None


class TestImportStructures(unittest.TestCase):
    """
    Tests for `define_import_structure` and `spread_import_structure`.
    """

    # Directory three levels above this test file; `src/transformers/models` is resolved relative to it.
    base_transformers_path = Path(__file__).parent.parent.parent
    models_path = base_transformers_path / "src" / "transformers" / "models"
    # Evaluated once, when the class body is executed.
    models_import_structure = spread_import_structure(define_import_structure(models_path))

    def test_definition(self):
        """
        Checks that `define_import_structure`, run on the `import_structures` fixture directory, returns exactly
        the expected mapping of backend frozensets to modules to exported object names.
        """
        import_structure = define_import_structure(import_structures)
        valid_frozensets: dict[frozenset[str], dict[str, set[str]]] = {
            frozenset(): {
                "import_structure_raw_register": {"A0", "A4", "a0"},
                "import_structure_register_with_comments": {"B0", "b0"},
            },
            frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
            frozenset({"torch"}): {
                "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"}
            },
            frozenset({"tf", "torch"}): {
                "import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
                "import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
            },
            frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
            frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
            frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
            frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
            frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
            frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
            frozenset({"torch>=2.5", "accelerate<0.20"}): {
                "import_structure_raw_register_with_versions": {"D6", "d6"}
            },
        }

        self.assertEqual(len(import_structure), len(valid_frozensets))
        for _frozenset, expected_modules in valid_frozensets.items():
            self.assertIn(_frozenset, import_structure)
            found_modules = import_structure[_frozenset]
            # Compare module names as sets so the check does not depend on the insertion order of the
            # returned dictionary.
            self.assertSetEqual(set(found_modules.keys()), set(expected_modules.keys()))
            for module, objects in expected_modules.items():
                self.assertIn(module, found_modules)
                self.assertSetEqual(objects, found_modules[module])

    def test_transformers_specific_model_import(self):
        """
        This test ensures that there is equivalence between what is written down in __all__ and what is
        written down with register().

        It doesn't test the backends attributed to register().
        """
        for architecture in os.listdir(self.models_path):
            # Skip plain files, private entries and the "deprecated" folder.
            if (
                os.path.isfile(self.models_path / architecture)
                or architecture.startswith("_")
                or architecture == "deprecated"
            ):
                continue

            with self.subTest(f"Testing arch {architecture}"):
                import_structure = define_import_structure(self.models_path / architecture)

                # Merge the objects of every backend requirement into a single list per module.
                backend_agnostic_import_structure = {}
                for module_object_mapping in import_structure.values():
                    for module, objects in module_object_mapping.items():
                        backend_agnostic_import_structure.setdefault(module, []).extend(objects)

                for module, objects in backend_agnostic_import_structure.items():
                    module_file = self.models_path / architecture / f"{module}.py"
                    # Explicit encoding so the read does not depend on the platform's default.
                    with open(module_file, encoding="utf-8") as f:
                        content = f.read()

                    _all = fetch__all__(content)
                    if _all is None:
                        raise ValueError(f"{module} doesn't have __all__ defined.")

                    error_message = (
                        f"{module_file} doesn't seem to be defined correctly:\n"
                        f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
                    )
                    self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)

    def test_import_spread(self):
        """
        This test is specifically designed to test that varying levels of depth across import structures are
        respected.

        In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
        - models.{frozensets}
        - models.albert.{frozensets}
        - models.deprecated.transfo_xl.{frozensets}
        """
        initial_import_structure = {
            frozenset(): {"dummy_non_model": {"DummyObject"}},
            "models": {
                frozenset(): {"dummy_config": {"DummyConfig"}},
                "albert": {
                    frozenset(): {"configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"}},
                    frozenset({"torch"}): {
                        "modeling_albert": {
                            "AlbertForMaskedLM",
                        }
                    },
                },
                "llama": {
                    frozenset(): {"configuration_llama": {"LlamaConfig"}},
                    frozenset({"torch"}): {
                        "modeling_llama": {
                            "LlamaForCausalLM",
                        }
                    },
                },
                "deprecated": {
                    "transfo_xl": {
                        frozenset({"torch"}): {
                            "modeling_transfo_xl": {
                                "TransfoXLModel",
                            }
                        },
                        frozenset(): {
                            "configuration_transfo_xl": {"TransfoXLConfig"},
                            "tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
                        },
                    },
                    "deta": {
                        frozenset({"torch"}): {
                            "modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
                        },
                        frozenset(): {"configuration_deta": {"DetaConfig"}},
                        frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
                    },
                },
            },
        }

        # Expected result: one flat mapping per frozenset, with module names turned into dotted paths.
        ground_truth_spread_import_structure = {
            frozenset(): {
                "dummy_non_model": {"DummyObject"},
                "models.dummy_config": {"DummyConfig"},
                "models.albert.configuration_albert": {"AlbertConfig", "AlbertOnnxConfig"},
                "models.llama.configuration_llama": {"LlamaConfig"},
                "models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
                "models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
                "models.deprecated.deta.configuration_deta": {"DetaConfig"},
            },
            frozenset({"torch"}): {
                "models.albert.modeling_albert": {"AlbertForMaskedLM"},
                "models.llama.modeling_llama": {"LlamaForCausalLM"},
                "models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
                "models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
            },
            frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
        }

        newly_spread_import_structure = spread_import_structure(initial_import_structure)

        self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)


@pytest.mark.parametrize(
    "backend,package_name,version_comparison,version",
    [
        pytest.param(Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"),
        pytest.param(Backend("tf<=1"), "tf", VersionComparison.LESS_THAN_OR_EQUAL.value, "1"),
        pytest.param(Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"),
    ],
)
def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str):
    """
    Checks that a `Backend` built from a requirement string exposes the expected package name, version
    comparison and version.

    The first case carries a trailing space after the version and still expects the version "2.5".
    """
    assert backend.package_name == package_name

    # `backend.version_comparison` is resolved through `VersionComparison.from_string` before comparing.
    resolved_comparison = VersionComparison.from_string(backend.version_comparison)
    assert resolved_comparison == version_comparison

    assert backend.version == version