Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,547 Bytes
42f2c22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
"""
Decorators.
"""
import functools
import threading
import time
from typing import Callable
import torch
from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank
from common.logger import get_logger
logger = get_logger(__name__)
def log_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will log the function name at entry.

    Each call first emits ``Entering <func.__name__>`` through the module-level
    ``logger`` (at INFO level) and then forwards all arguments to ``func``.

    When using multiple decorators, this must be applied innermost to properly capture the name:
    if it wraps another decorator's wrapper that does not preserve ``__name__``,
    the wrapper's name is what gets logged.

    Args:
        func: The function to decorate.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result unchanged.
    """

    # Copy __name__/__doc__/__wrapped__ onto the wrapper so that decorators
    # stacked outside this one (and debuggers/loggers) still see func's identity.
    @functools.wraps(func)
    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper
def barrier_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will start executing when all ranks are ready to enter.

    Each call invokes ``barrier_if_distributed()`` before forwarding all
    arguments to ``func``.

    Args:
        func: The function to decorate.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result unchanged.
    """

    # Keep func's metadata (__name__, __doc__, ...) visible on the wrapper.
    @functools.wraps(func)
    def barrier_on_entry_wrapper(*args, **kwargs):
        # Synchronize first, then run the actual function.
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper
def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
    """
    Helper function for local_rank_zero_only and global_rank_zero_only.

    Builds a wrapper that calls ``func`` only when ``execute`` is true and then
    calls ``barrier_if_distributed()`` regardless of whether ``func`` ran.

    Args:
        execute: Whether the wrapper should actually call ``func``. The value is
            fixed when the wrapper is created, not re-evaluated per call.
        func: The function to wrap.

    Returns:
        A wrapper returning ``func``'s result when ``execute`` is true,
        otherwise ``None``.

    Note:
        If ``func`` raises, the exception propagates before the barrier call
        is reached.
    """

    # Keep func's metadata (__name__, __doc__, ...) visible on the wrapper.
    @functools.wraps(func)
    def conditional_execute_wrapper(*args, **kwargs):
        # Only execute if needed.
        result = func(*args, **kwargs) if execute else None
        # All GPUs must wait.
        barrier_if_distributed()
        # Return results.
        return result

    return conditional_execute_wrapper
def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
    """
    Helper function for some functions with special constraints,
    especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
    in case they are wrongly invoked in other scenarios.

    Args:
        condition: Whether calls to the wrapped function are permitted. The value
            is fixed when the wrapper is created, not re-evaluated per call.
        func: The function to wrap.
        err_msg: Message attached to the error raised when ``condition`` is false.

    Returns:
        A wrapper that forwards to ``func`` and returns its result when
        ``condition`` is true.

    Raises:
        AssertionError: Raised by the wrapper on every call when ``condition`` is false.
    """

    # Keep func's metadata (__name__, __doc__, ...) visible on the wrapper.
    @functools.wraps(func)
    def asserted_execute_wrapper(*args, **kwargs):
        # Raise explicitly rather than using an `assert` statement: asserts are
        # removed under `python -O`, which would silently disable this guard.
        if not condition:
            raise AssertionError(err_msg)
        return func(*args, **kwargs)

    return asserted_execute_wrapper
def local_rank_zero_only(func: Callable) -> Callable:
    """
    Restrict the decorated function so that it only runs on local rank zero.

    The local rank is read once, when the decorator is applied. The returned
    wrapper comes from ``_conditional_execute_wrapper_factory``: it returns
    ``None`` without calling ``func`` when the rank check failed, and calls
    ``barrier_if_distributed()`` after every invocation.
    """
    is_local_zero = get_local_rank() == 0
    return _conditional_execute_wrapper_factory(execute=is_local_zero, func=func)
def global_rank_zero_only(func: Callable) -> Callable:
    """
    Restrict the decorated function so that it only runs on global rank zero.

    The global rank is read once, when the decorator is applied. The returned
    wrapper comes from ``_conditional_execute_wrapper_factory``: it returns
    ``None`` without calling ``func`` when the rank check failed, and calls
    ``barrier_if_distributed()`` after every invocation.
    """
    is_global_zero = get_global_rank() == 0
    return _conditional_execute_wrapper_factory(execute=is_global_zero, func=func)
def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Guard the decorated function so only global rank zero may call it.

    The global rank is read once, when the decorator is applied. The returned
    wrapper comes from ``_asserted_wrapper_factory`` and fails an assertion on
    call when that rank was not zero.
    """
    on_global_zero = get_global_rank() == 0
    message = "Not accessible to processes with global_rank != 0"
    return _asserted_wrapper_factory(on_global_zero, func, err_msg=message)
def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Guard the decorated function so only local rank zero may call it.

    The local rank is read once, when the decorator is applied. The returned
    wrapper comes from ``_asserted_wrapper_factory`` and fails an assertion on
    call when that rank was not zero.
    """
    on_local_zero = get_local_rank() == 0
    message = "Not accessible to processes with local_rank != 0"
    return _asserted_wrapper_factory(on_local_zero, func, err_msg=message)
def new_thread(func: Callable) -> Callable:
    """
    Functions with this decorator will run in a new thread.
    The function will return the thread, which can be joined to wait for completion.

    Args:
        func: The function to run in a separate thread.

    Returns:
        A wrapper that starts a ``threading.Thread`` targeting ``func`` with the
        given arguments and returns the already-started thread. The wrapper
        does not return ``func``'s own return value.
    """

    # Keep func's metadata (__name__, __doc__, ...) visible on the wrapper.
    @functools.wraps(func)
    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper
def log_runtime(func: Callable) -> Callable:
    """
    Functions with this decorator will log their runtime.

    The wrapper calls ``barrier_if_distributed()`` before starting the timer
    and again after ``func`` returns, so the reported duration includes the
    time spent in the trailing barrier. The elapsed time is logged through
    the module-level ``logger`` (at INFO level).

    Args:
        func: The function to time.

    Returns:
        A wrapper with the same call signature that returns ``func``'s result unchanged.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        # Use the same barrier helper as the other decorators in this module
        # instead of calling torch.distributed.barrier() directly, which raises
        # when no process group has been initialized.
        # NOTE(review): assumes barrier_if_distributed() is a no-op outside a
        # distributed context -- confirm against common.distributed.
        barrier_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        barrier_if_distributed()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped
|