|
|
|
|
|
|
|
import os |
|
from typing import List, Optional |
|
|
|
import pytest |
|
|
|
|
|
|
|
# World sizes (compare the WORLD_SIZE env var read during collection) that the
# test suite knows about. Not referenced by the hooks in this file; presumably
# imported by tests/fixtures elsewhere — confirm before changing.
WORLD_SIZE_OPTIONS = (1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shared fixture module as a pytest plugin so its fixtures are
# available to every test without an explicit import.
pytest_plugins = ['tests.fixtures.fixtures']
|
|
|
|
|
def _get_world_size(item: pytest.Item):
    """Return the world size a test asks for via its ``world_size`` marker.

    The closest ``world_size`` marker on ``item`` wins. When the test carries
    no such marker, a synthetic ``world_size(1)`` marker is used instead, so
    the result defaults to 1. The value returned is the marker's first
    positional argument.
    """
    fallback = pytest.mark.world_size(1).mark
    closest = item.get_closest_marker('world_size')
    chosen = fallback if closest is None else closest
    return chosen.args[0]
|
|
|
|
|
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:
    """Resolve a string option, preferring the command line over the ini file.

    Lookup order is ``--{name}`` on the command line, then the ``{name}`` ini
    key, then ``default``. If none of those provides a value, the run is
    aborted with ``pytest.fail``.

    Args:
        config: The active pytest configuration.
        name: The option name as registered by ``_add_option``.
        default: Fallback value; ``None`` means the option is required.

    Returns:
        The resolved option value.
    """
    from_cli = config.getoption(name)
    if from_cli is not None:
        assert isinstance(from_cli, str)
        return from_cli

    resolved = config.getini(name)
    # An unset ini key may be reported as either [] or None; both mean "missing".
    if resolved == [] or resolved is None:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required')
        resolved = default

    assert isinstance(resolved, str)
    return resolved
|
|
|
|
|
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[List[str]] = None,
) -> None:
    """Register ``name`` both as a ``--{name}`` command-line flag and an ini key.

    Registering in both places lets ``_get_option`` resolve the value from
    either source (command line first, then the ini file).

    Args:
        parser: The pytest parser handed to ``pytest_addoption``.
        name: Option name; exposed as ``--{name}`` on the command line and as
            ``{name}`` in the ini file.
        help: Help text used for both the flag and the ini key. It shadows the
            builtin ``help`` but is kept because callers pass it by keyword.
        choices: Optional list of accepted values. NOTE: argparse enforces this
            only for the command-line flag; ini values are not validated here.
    """
    parser.addoption(
        f'--{name}',
        # None signals "not given on the command line" so the ini value is consulted.
        default=None,
        type=str,
        choices=choices,
        help=help,
    )
    parser.addini(
        name=name,
        help=help,
        type='string',
        default=None,
    )
|
|
|
|
|
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Keep only tests whose world size matches ``$WORLD_SIZE`` (multi-GPU filtering).

    The target world size is read from the ``WORLD_SIZE`` environment
    variable (default ``'1'``) and printed. Each test's world size comes from
    ``_get_world_size``. Non-matching tests are reported through the
    ``pytest_deselected`` hook, and ``items`` is rewritten in place to the
    matching tests, in their original order.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    keep: List[pytest.Item] = []
    drop: List[pytest.Item] = []
    for item in items:
        bucket = keep if _get_world_size(item) == world_size else drop
        bucket.append(item)

    if drop:
        config.hook.pytest_deselected(items=drop)
    # Slice assignment mutates the list pytest owns rather than rebinding the name.
    items[:] = keep
|
|
|
|
|
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register this suite's custom command-line/ini options (currently ``seed``)."""
    # Build the help text from single-line literals: a triple-quoted string
    # would carry its line breaks and source indentation into the help, and
    # pytest's ini-option listing turns those into stray spaces.
    _add_option(
        parser,
        'seed',
        help=(
            'Rank zero seed to use. '
            '`reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked '
            'before each test.'
        ),
    )
|
|
|
|
|
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """Report a session in which no tests ran as successful.

    ``pytest_collection_modifyitems`` deselects every test whose world size
    does not match ``$WORLD_SIZE``, which can leave nothing to run. pytest
    would report that with a non-zero exit status (``NO_TESTS_COLLECTED``),
    so this hook rewrites it to ``OK``.

    NOTE(review): this also masks genuinely empty runs (e.g. a ``-k`` filter
    that matches nothing).
    """
    # ExitCode is an IntEnum, so it compares equal to the plain int pytest may pass.
    if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = pytest.ExitCode.OK
|
|