File size: 2,746 Bytes
9c4ca75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import os
from typing import List, Optional
import pytest
# Disabled until dependencies are resolved (see the TODO below); re-enable this
# import together with the deterministic-mode call further down.
# from composer.utils import reproducibility

# Allowed options for pytest.mark.world_size().
# NOTE(review): not referenced by the hooks defined in this file; confirm
# whether anything validates `world_size` markers against this tuple.
WORLD_SIZE_OPTIONS = (1, 2)

# Intended to enforce deterministic mode before any tests start. Currently
# disabled because the `reproducibility` import above is commented out.
# reproducibility.configure_deterministic_mode()

# TODO: re-enable the commented-out plugins once their dependencies are resolved.
# Add the path of any pytest fixture files you want to make global.
pytest_plugins = [
    # 'tests.fixtures.autouse',
    'tests.fixtures.fixtures',
]
def _get_world_size(item: pytest.Item):
    """Look up the world size a test item was marked with.

    The nearest ``@pytest.mark.world_size(n)`` marker on the item wins. An
    item with no such marker is treated as if it carried ``world_size(1)``.

    Args:
        item: The collected test item to inspect.

    Returns:
        The first positional argument of the closest ``world_size`` marker,
        or ``1`` when the item has none.
    """
    # Synthetic ``world_size(1)`` mark used as the stand-in for unmarked items.
    fallback_mark = pytest.mark.world_size(1).mark
    closest_mark = item.get_closest_marker('world_size', default=fallback_mark)
    world_size = closest_mark.args[0]
    return world_size
def _get_option(
    config: pytest.Config,
    name: str,
    default: Optional[str] = None,
) -> str:  # type: ignore
    """Resolve a string option from the CLI, then the ini file, then a fallback.

    Lookup order: the ``--<name>`` command-line value, then the ``<name>`` ini
    key, then ``default``.

    Args:
        config: The active pytest config.
        name: Option name, as registered for both the CLI and the ini file.
        default: Value to use when neither the CLI nor the ini file sets it.
            ``None`` means the option is required.

    Returns:
        The resolved option value.

    Raises:
        Calls ``pytest.fail`` when no source provides a value and ``default``
        is ``None``.
    """
    cli_value = config.getoption(name)
    if cli_value is not None:
        assert isinstance(cli_value, str)
        return cli_value

    ini_value = config.getini(name)
    # An unset ini key may surface as either ``[]`` or ``None``; both mean
    # "not provided".
    if ini_value == [] or ini_value is None:
        if default is None:
            pytest.fail(f'Config option {name} is not specified but is required')
        ini_value = default

    assert isinstance(ini_value, str)
    return ini_value
def _add_option(
    parser: pytest.Parser,
    name: str,
    help: str,
    choices: Optional[list[str]] = None,
):
    """Register ``name`` both as a ``--name`` CLI flag and as an ini key.

    Both registrations default to ``None``, so an unset option is reported as
    missing rather than as a placeholder value.

    Args:
        parser: The pytest argument parser passed to ``pytest_addoption``.
        name: Option name; the CLI flag becomes ``--<name>``.
        help: Help text shared by the CLI flag and the ini key.
        choices: Optional whitelist of accepted CLI values.
    """
    cli_flag = f'--{name}'
    parser.addoption(cli_flag, default=None, type=str, choices=choices, help=help)
    parser.addini(name=name, help=help, type='string', default=None)
def pytest_collection_modifyitems(
    config: pytest.Config,
    items: List[pytest.Item],
) -> None:
    """Keep only tests whose ``world_size`` marker matches ``$WORLD_SIZE``.

    ``WORLD_SIZE`` is read from the environment (default ``1``). Unmarked tests
    count as ``world_size(1)``. Tests that do not match are reported through the
    ``pytest_deselected`` hook, and ``items`` is filtered in place.
    """
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    print(f'world_size={world_size}')

    def _matches_world_size(item: pytest.Item) -> bool:
        return _get_world_size(item) == world_size

    # A test is kept only if every predicate holds; extend this tuple to add filters.
    keep_predicates = (_matches_world_size,)

    kept: List[pytest.Item] = []
    dropped: List[pytest.Item] = []
    for item in items:
        bucket = kept if all(pred(item) for pred in keep_predicates) else dropped
        bucket.append(item)

    if dropped:
        config.hook.pytest_deselected(items=dropped)
    items[:] = kept
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register this test suite's custom options (currently only ``seed``)."""
    # NOTE(review): this help text promises seeding via
    # `reproducibility.seed_all`, but the `reproducibility` import in this file
    # is commented out; confirm where (if anywhere) the seed is consumed.
    seed_help = """\
Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
before each test."""
    _add_option(parser, name='seed', help=seed_help)
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
    """Report a run in which no tests ran as a success.

    pytest exits with ``ExitCode.NO_TESTS_COLLECTED`` (5) when nothing runs.
    That can happen here, e.g. when ``pytest_collection_modifyitems``
    deselects every test because none matches ``$WORLD_SIZE``. Such runs are
    rewritten to ``ExitCode.OK`` (0) so they do not fail.

    Args:
        session: The finishing pytest session; its ``exitstatus`` may be reset.
        exitstatus: The exit status pytest is about to return.
    """
    # Use the named exit codes instead of the bare 5 / 0. ExitCode is an
    # IntEnum, so the comparison and the assigned value are unchanged.
    if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
        session.exitstatus = pytest.ExitCode.OK
|