File size: 4,547 Bytes
42f2c22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Decorators.
"""

import functools
import threading
import time
from typing import Callable
import torch

from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank
from common.logger import get_logger

logger = get_logger(__name__)


def log_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will log the function name at entry.

    The wrapper is built with ``functools.wraps``, so it carries ``func``'s
    ``__name__``, ``__doc__`` and other metadata. When stacking decorators,
    apply this one innermost (closest to the ``def``) if anything it would
    otherwise wrap does not preserve ``__name__``; in that case the logged
    name would be the inner wrapper's name instead of the real function's.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that logs ``"Entering <name>"`` at INFO level, then calls
        ``func`` with the same arguments and returns its result.
    """

    @functools.wraps(func)
    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper


def barrier_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will start executing when all ranks are ready to enter.

    The wrapper calls ``barrier_if_distributed()`` before invoking ``func``.
    Judging by the helper's name, it only synchronizes in distributed runs;
    that is not verified here. ``functools.wraps`` keeps ``func``'s name and
    docstring on the wrapper.

    Args:
        func: The function to wrap.

    Returns:
        A wrapper that synchronizes first, then returns ``func``'s result.
    """

    @functools.wraps(func)
    def barrier_on_entry_wrapper(*args, **kwargs):
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper


def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
    """
    Helper function for local_rank_zero_only and global_rank_zero_only.

    Args:
        execute: Whether this process should run ``func``. The caller decides
            this once, at decoration time.
        func: The function to wrap.

    Returns:
        A wrapper (with ``func``'s metadata, via ``functools.wraps``) that runs
        ``func`` only when ``execute`` is true. Afterwards, every caller runs
        ``barrier_if_distributed()``. Returns ``func``'s result on executing
        processes and ``None`` elsewhere.
    """

    @functools.wraps(func)
    def conditional_execute_wrapper(*args, **kwargs):
        # Only execute if needed.
        result = func(*args, **kwargs) if execute else None
        # All GPUs must wait, including the ones that skipped execution, so the
        # skipping ranks do not run ahead of the executing one.
        # NOTE(review): if ``func`` raises, this process never reaches the
        # barrier below, and the other ranks may block in it.
        barrier_if_distributed()
        # Return results.
        return result

    return conditional_execute_wrapper


def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
    """
    Helper function for some functions with special constraints,
    especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
    in case they are wrongly invoked in other scenarios.

    Args:
        condition: Whether this process may call ``func``. The caller decides
            this once, at decoration time.
        func: The function to guard.
        err_msg: Message of the ``AssertionError`` raised when ``condition``
            is false.

    Returns:
        A wrapper (with ``func``'s metadata, via ``functools.wraps``) that
        returns ``func``'s result when ``condition`` is true. Otherwise it
        raises ``AssertionError(err_msg)`` without calling ``func``.
    """

    @functools.wraps(func)
    def asserted_execute_wrapper(*args, **kwargs):
        # Raise explicitly instead of using an ``assert`` statement. Asserts
        # are stripped under ``python -O``, which would silently let the
        # guarded function run on the wrong process. AssertionError is kept so
        # existing callers that catch it still work.
        if not condition:
            raise AssertionError(err_msg)
        return func(*args, **kwargs)

    return asserted_execute_wrapper


def local_rank_zero_only(func: Callable) -> Callable:
    """
    Decorate ``func`` so that it runs only on the process with local rank zero.

    The rank is checked once, when the decorator is applied. Every process that
    calls the returned wrapper then runs ``barrier_if_distributed()``. Processes
    that are not local rank zero skip ``func`` and get ``None`` back.
    """
    on_local_rank_zero = get_local_rank() == 0
    return _conditional_execute_wrapper_factory(execute=on_local_rank_zero, func=func)


def global_rank_zero_only(func: Callable) -> Callable:
    """
    Decorate ``func`` so that it runs only on the process with global rank zero.

    The rank is checked once, when the decorator is applied. Every process that
    calls the returned wrapper then runs ``barrier_if_distributed()``. Processes
    other than global rank zero skip ``func`` and get ``None`` back.
    """
    on_global_rank_zero = get_global_rank() == 0
    return _conditional_execute_wrapper_factory(execute=on_global_rank_zero, func=func)


def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Restrict ``func`` so that only the global-rank-zero process may call it.

    The rank is checked once, when the decorator is applied. On any other rank,
    calling the returned wrapper is rejected by ``_asserted_wrapper_factory``
    with an AssertionError carrying the message below.
    """
    message = "Not accessible to processes with global_rank != 0"
    is_global_rank_zero = get_global_rank() == 0
    return _asserted_wrapper_factory(is_global_rank_zero, func, err_msg=message)


def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Restrict ``func`` so that only the local-rank-zero process may call it.

    The rank is checked once, when the decorator is applied. On any other local
    rank, calling the returned wrapper is rejected by ``_asserted_wrapper_factory``
    with an AssertionError carrying the message below.
    """
    message = "Not accessible to processes with local_rank != 0"
    is_local_rank_zero = get_local_rank() == 0
    return _asserted_wrapper_factory(is_local_rank_zero, func, err_msg=message)


def new_thread(func: Callable) -> Callable:
    """
    Functions with this decorator will run in a new thread.
    The function will return the thread, which can be joined to wait for completion.

    Args:
        func: The function to run in the background.

    Returns:
        A wrapper (with ``func``'s metadata, via ``functools.wraps``). The
        wrapper starts a ``threading.Thread`` running ``func(*args, **kwargs)``
        and returns that already-started thread. ``func``'s return value is
        not captured; exceptions it raises are not propagated to the caller.
    """

    @functools.wraps(func)
    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper


def log_runtime(func: Callable) -> Callable:
    """
    Functions with this decorator will log their runtime.

    Synchronization happens before the timer starts and again after ``func``
    returns. In distributed runs, the logged duration therefore includes time
    spent waiting for slower ranks.

    Args:
        func: The function to time.

    Returns:
        A wrapper (with ``func``'s metadata) that returns ``func``'s result. It
        logs ``"Completed <name> in <seconds> seconds."`` at INFO level.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # Use the same barrier helper as the other decorators in this module.
        # A bare torch.distributed.barrier() raises when no default process
        # group has been initialized (e.g. single-process runs), whereas
        # barrier_if_distributed() is presumably a no-op there.
        barrier_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        barrier_if_distributed()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped