|
import traceback as tb |
|
from typing import Any |
|
|
|
|
|
# Shape of a captured failure: the exception instance paired with the
# stack summary extracted from its traceback (built by _wrap_exception).
WRAPPED_EXCEPTION = tuple[BaseException, tb.StackSummary]

# Names exported by a star-import of this module.
__all__ = ["CheckpointException"]
|
|
|
|
|
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: |
|
return (exc, tb.extract_tb(exc.__traceback__)) |
|
|
|
|
|
def _is_wrapped_exception(obj: Any) -> bool: |
|
if not isinstance(obj, tuple): |
|
return False |
|
if len(obj) != 2: |
|
return False |
|
return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) |
|
|
|
|
|
class CheckpointException(BaseException):
    """Exception raised if failure was detected as part of a checkpoint load or save.

    Holds a mapping from rank to the wrapped failure recorded for that rank,
    where each wrapped failure is an ``(exception, stack summary)`` pair.
    """

    def __init__(self, msg: str, failures: dict[int, WRAPPED_EXCEPTION]):
        """Create the exception.

        Args:
            msg: Message describing the failure.
            failures: Mapping from rank to its ``(exception, stack summary)``
                pair.
        """
        # Both arguments are forwarded so they are stored in ``self.args``.
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> dict[int, WRAPPED_EXCEPTION]:
        """Return a dictionary mapping node ranks to their associated exceptions in case of failure."""
        return self._failures

    def __str__(self) -> str:
        """Render a header line followed by one traceback section per rank.

        Each section has a "Traceback" banner naming the rank, the formatted
        stack entries (skipped when the stored summary is ``None``), and the
        formatted exception line(s).
        """
        # Collect the pieces and join once; this also avoids naming a local
        # ``str``, which would shadow the builtin inside this method.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)
|
|