# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.loops.progress import _BaseProgress


class _Loop:
    """Basic Loops interface."""

    def __init__(self, trainer: "pl.Trainer") -> None:
        self._restarting = False
        self._loaded_from_state_dict = False
        self.trainer = trainer

    @property
    def restarting(self) -> bool:
        """Whether the state of this loop was reloaded and it needs to restart."""
        return self._restarting

    @restarting.setter
    def restarting(self, restarting: bool) -> None:
        """Connects this loop's restarting value and its children."""
        self._restarting = restarting
        for loop in vars(self).values():
            if isinstance(loop, _Loop):
                loop.restarting = restarting

    def reset_restart_stage(self) -> None:
        """Hook to reset any restart-stage bookkeeping; a no-op in the base loop."""

    def on_save_checkpoint(self) -> dict:
        """Called when saving a model checkpoint; use this hook to persist loop state.

        Returns:
            The current loop state.

        """
        return {}

    def on_load_checkpoint(self, state_dict: dict) -> None:
        """Called when loading a model checkpoint, use to reload loop state."""

    def state_dict(self, destination: Optional[dict] = None, prefix: str = "") -> dict:
        """The state dict is determined by the state and progress of this loop and all its children.

        Args:
            destination: An existing dictionary to update with this loop's state. By default a new dictionary
                is returned.
            prefix: A prefix for each key in the state dictionary.

        """
        if destination is None:
            destination = {}

        destination[prefix + "state_dict"] = self.on_save_checkpoint()

        for k, v in self.__dict__.items():
            key = prefix + k
            if isinstance(v, _BaseProgress):
                destination[key] = v.state_dict()
            elif isinstance(v, _Loop):
                v.state_dict(destination, key + ".")
        return destination

    def load_state_dict(
        self,
        state_dict: dict,
        prefix: str = "",
    ) -> None:
        """Loads the state of this loop and all its children."""
        self._load_from_state_dict(state_dict.copy(), prefix)
        for k, v in self.__dict__.items():
            if isinstance(v, _Loop):
                v.load_state_dict(state_dict.copy(), prefix + k + ".")
        self.restarting = True
        self._loaded_from_state_dict = True

    def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None:
        for k, v in self.__dict__.items():
            key = prefix + k
            if key not in state_dict:
                # compatibility with old checkpoints
                continue
            if isinstance(v, _BaseProgress):
                v.load_state_dict(state_dict[key])
        if prefix + "state_dict" in state_dict:  # compatibility with old checkpoints
            self.on_load_checkpoint(state_dict[prefix + "state_dict"])

    def on_iteration_done(self) -> None:
        """Called when an iteration finishes; clears the restart flags and resets the restart stage."""
        self._restarting = False
        self._loaded_from_state_dict = False
        self.reset_restart_stage()
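

# Illustrative sketch only, not part of Lightning's API: the classes and the
# ``__main__`` block below are hypothetical and exist purely to show how
# ``_Loop.state_dict`` flattens nested loop state under dotted prefixes and how
# ``_Loop.load_state_dict`` restores it and flags every loop as restarting.
if __name__ == "__main__":

    class _DemoProgress(_BaseProgress):
        """Hypothetical stand-in for the progress trackers in ``pytorch_lightning.loops.progress``."""

        def __init__(self) -> None:
            self.total = 0

        def state_dict(self) -> dict:
            return {"total": self.total}

        def load_state_dict(self, state_dict: dict) -> None:
            self.total = state_dict["total"]

    class _DemoChildLoop(_Loop):
        def __init__(self, trainer: Optional["pl.Trainer"]) -> None:
            super().__init__(trainer)
            self.progress = _DemoProgress()

    class _DemoParentLoop(_Loop):
        def __init__(self, trainer: Optional["pl.Trainer"]) -> None:
            super().__init__(trainer)
            self.child = _DemoChildLoop(trainer)

    # A real loop receives a ``pl.Trainer``; ``None`` is enough for this sketch.
    parent = _DemoParentLoop(trainer=None)
    parent.child.progress.total = 7

    state = parent.state_dict()
    # {'state_dict': {}, 'child.state_dict': {}, 'child.progress': {'total': 7}}
    print(state)

    restored = _DemoParentLoop(trainer=None)
    restored.load_state_dict(state)
    assert restored.child.progress.total == 7
    assert restored.restarting and restored.child.restarting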