|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
from typing import Optional |
|
|
|
from pytorch_lightning.utilities.enums import LightningEnum |
|
|
|
|
|
class TrainerStatus(LightningEnum):
    """Status of a :class:`~pytorch_lightning.trainer.trainer.Trainer`.

    The members are ``INITIALIZING``, ``RUNNING``, ``FINISHED`` and ``INTERRUPTED``; the last two are
    reported as "stopped" by :attr:`stopped`.
    """

    INITIALIZING = "initializing"
    RUNNING = "running"
    FINISHED = "finished"
    INTERRUPTED = "interrupted"

    @property
    def stopped(self) -> bool:
        """``True`` when the status is ``FINISHED`` or ``INTERRUPTED``, ``False`` otherwise."""
        # Kept as a local tuple: a class-level container here would itself become an enum member.
        stop_statuses = (TrainerStatus.FINISHED, TrainerStatus.INTERRUPTED)
        return self in stop_statuses
|
|
|
|
|
class TrainerFn(LightningEnum):
    """Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer` such as
    :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
    :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.

    Each member's value is the lowercase function name (``"fit"``, ``"validate"``, ``"test"``,
    ``"predict"``). The finer-grained stage executing inside a function is tracked separately by
    :class:`RunningStage`.
    """

    FITTING = "fit"  # fit
    VALIDATING = "validate"  # validate
    TESTING = "test"  # test
    PREDICTING = "predict"  # predict
|
|
|
|
|
class RunningStage(LightningEnum):
    """Stage the trainer is currently running.

    Complements :class:`TrainerFn`: a :class:`TrainerFn` names the user-facing function, while a
    ``RunningStage`` names the stage currently executing within it. One function may go through
    several stages:

    - ``TrainerFn.FITTING`` - ``RunningStage.{SANITY_CHECKING,TRAINING,VALIDATING}``
    - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
    - ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
    - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
    """

    TRAINING = "train"
    SANITY_CHECKING = "sanity_check"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"

    @property
    def evaluating(self) -> bool:
        """``True`` for ``VALIDATING``, ``TESTING`` and ``SANITY_CHECKING``; ``False`` otherwise."""
        # Local tuple on purpose: a class-level container would be turned into an enum member.
        evaluation_stages = (RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.SANITY_CHECKING)
        return self in evaluation_stages

    @property
    def dataloader_prefix(self) -> Optional[str]:
        """String prefix associated with this stage's dataloaders.

        ``VALIDATING`` and ``SANITY_CHECKING`` both map to ``"val"``; every other stage maps to its own
        value (``"train"``, ``"test"``, ``"predict"``). Every stage defined here yields a string.
        NOTE(review): ``Optional`` is presumably kept for stages that have no prefix; confirm before
        narrowing the annotation.
        """
        if self not in (RunningStage.VALIDATING, RunningStage.SANITY_CHECKING):
            return self.value
        return "val"
|
|
|
|
|
@dataclass
class TrainerState:
    """Mutable container for the current :class:`~pytorch_lightning.trainer.trainer.Trainer` state.

    Attributes:
        status: current :class:`TrainerStatus`; defaults to ``TrainerStatus.INITIALIZING``.
        fn: the :class:`TrainerFn` currently set, or ``None`` when unset.
        stage: the :class:`RunningStage` currently set, or ``None`` when unset.
    """

    # Field order defines the generated ``__init__`` signature, so it must not change.
    status: TrainerStatus = TrainerStatus.INITIALIZING
    fn: Optional[TrainerFn] = None
    stage: Optional[RunningStage] = None

    @property
    def finished(self) -> bool:
        """Whether :attr:`status` compares equal to ``TrainerStatus.FINISHED``."""
        current = self.status
        # ``==`` (not ``is``) is kept so the comparison goes through the enum's equality semantics.
        return current == TrainerStatus.FINISHED

    @property
    def stopped(self) -> bool:
        """Whether :attr:`status` is stopped, as reported by ``TrainerStatus.stopped``."""
        current = self.status
        return current.stopped
|
|