|
|
|
|
|
|
|
|
|
|
|
import functools |
|
import logging |
|
import sys |
|
import os |
|
import torch |
|
from deepspeed.utils.torch import required_torch_version |
|
|
|
# Mapping of lowercase level names to their stdlib ``logging`` constants,
# used by ``should_log_le`` to translate string level names.
log_levels = {name: getattr(logging, name.upper()) for name in ("debug", "info", "warning", "error", "critical")}
|
|
|
|
|
class LoggerFactory:
    """Factory for stdout loggers configured with DeepSpeed's standard format."""

    @staticmethod
    def create_logger(name=None, level=logging.INFO):
        """Create a logger that writes to stdout with a uniform format.

        Args:
            name (str): name of the logger; must not be ``None``
            level: logging level applied to both the logger and its handler

        Returns:
            logging.Logger: the configured logger

        Raises:
            ValueError: if ``name`` is ``None``
        """
        if name is None:
            raise ValueError("name for logger cannot be None")

        logger_ = logging.getLogger(name)
        logger_.setLevel(level)
        # This logger owns its output; don't let records bubble to the root logger.
        logger_.propagate = False

        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(level)
        handler.setFormatter(
            logging.Formatter("[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s"))
        logger_.addHandler(handler)

        # Optionally register the logger's methods with torch._dynamo so that
        # logging calls are ignored while compiling (needs torch >= 2.6 and the
        # DISABLE_LOGS_WHILE_COMPILING=1 env var).
        if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
            # Methods listed in LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE keep working under compile.
            excluded_set = {
                item.strip()
                for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",")
            }
            default_methods = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'}
            for method in default_methods - excluded_set:
                torch._dynamo.config.ignore_logger_methods.add(getattr(logger_, method))
        return logger_
|
|
|
|
|
# Shared module-level DeepSpeed logger (stdout, INFO by default).
logger = LoggerFactory.create_logger(name="DeepSpeed", level=logging.INFO)
|
|
|
|
|
@functools.lru_cache(None)
def warning_once(*args, **kwargs):
    """Emit a warning via ``logger.warning()`` at most once per unique argument set.

    The deduplication is an ``lru_cache`` keyed on the call arguments, so two
    different call sites passing identical arguments share one cache entry.
    This assumes warning messages are unique across the codebase; if they are
    not, a cache that also hashes caller-frame information would be required.
    """
    logger.warning(*args, **kwargs)
|
|
|
|
|
# Attach the deduplicating helper so callers can use logger.warning_once(...).
logger.warning_once = warning_once
|
|
|
|
|
def print_configuration(args, name):
    """Log every attribute of ``args`` (e.g. an argparse Namespace) under a ``name:`` heading.

    Attributes are sorted alphabetically and padded with dots for alignment.
    """
    logger.info(f"{name}:")
    for key in sorted(vars(args)):
        # Dot-pad each key so the values line up in a column (width 29).
        padding = "." * (29 - len(key))
        logger.info("  {} {} {}".format(key, padding, getattr(args, key)))
|
|
|
|
|
def log_dist(message, ranks=None, level=logging.INFO):
    """Log message when one of following condition meets

    + not dist.is_initialized()
    + dist.get_rank() in ranks if ranks is not None or ranks = [-1]

    Args:
        message (str)
        ranks (list)
        level (int)

    """
    # NOTE: the docstring above was previously placed *after* this import, so
    # it was a no-op string expression rather than the function's __doc__.
    # Imported lazily to avoid a circular import at module load time.
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    # Rank is reported as -1 when torch.distributed is not initialized.
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        # ranks == [-1] means "log on every rank"; otherwise match this rank.
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        final_message = "[Rank {}] {}".format(my_rank, message)
        logger.log(level, final_message)
|
|
|
|
|
def print_json_dist(message, ranks=None, path=None):
    """Print message when one of following condition meets

    + not dist.is_initialized()
    + dist.get_rank() in ranks if ranks is not None or ranks = [-1]

    Args:
        message (dict): JSON-serializable payload; mutated to record ``rank``
        ranks (list)
        path (str): destination file for the JSON dump

    """
    # NOTE: the docstring above was previously placed *after* this import, so
    # it was a no-op string expression rather than the function's __doc__.
    # Imported lazily to avoid a circular import at module load time.
    from deepspeed import comm as dist
    import json
    should_log = not dist.is_initialized()
    ranks = ranks or []
    # Rank is reported as -1 when torch.distributed is not initialized.
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        # ranks == [-1] means "log on every rank"; otherwise match this rank.
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            # Flush Python's userspace buffer first: fsync only persists bytes
            # that have already been handed to the OS.
            outfile.flush()
            os.fsync(outfile.fileno())
|
|
|
|
|
def get_current_level():
    """Return the effective log level (an int such as ``logging.INFO``) of the DeepSpeed logger."""
    return logger.getEffectiveLevel()
|
|
|
|
|
def should_log_le(max_log_level_str):
    """Check whether messages at the given level would currently be emitted.

    Args:
        max_log_level_str: maximum log level as a string (case-insensitive)

    Returns ``True`` if the current log_level is less or equal to the specified log level. Otherwise ``False``.

    Raises:
        ValueError: if the argument is not a string or not a known level name.

    Example:

        ``should_log_le("info")`` will return ``True`` if the current log level is either ``logging.INFO`` or ``logging.DEBUG``
    """
    if not isinstance(max_log_level_str, str):
        raise ValueError(f"{max_log_level_str} is not a string")

    # Normalize before lookup so the error message echoes the lowercase name,
    # matching the keys of ``log_levels``.
    max_log_level_str = max_log_level_str.lower()
    requested_level = log_levels.get(max_log_level_str)
    if requested_level is None:
        raise ValueError(f"{max_log_level_str} is not one of the logging levels")

    return get_current_level() <= requested_level
|
|