# Copyright Lightning AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
from copy import deepcopy
from typing import Any, Callable, Optional, Union

from packaging.version import Version

import pytorch_lightning as pl
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from pytorch_lightning.callbacks import Checkpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.launchers import _SubprocessScriptLauncher
from pytorch_lightning.trainer.connectors.signal_connector import _get_sigkill_signal
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)


def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
r"""Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
as all errors should funnel through them.
Args:
trainer_fn: one of (fit, validate, test, predict)
*args: positional arguments to be passed to the `trainer_fn`
**kwargs: keyword arguments to be passed to `trainer_fn`
"""
try:
if trainer.strategy.launcher is not None:
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
return trainer_fn(*args, **kwargs)
except _TunerExitException:
_call_teardown_hook(trainer)
trainer._teardown()
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None
except KeyboardInterrupt as exception:
rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
# user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
signal.signal(signal.SIGINT, signal.SIG_IGN)
_interrupt(trainer, exception)
trainer._teardown()
launcher = trainer.strategy.launcher
if isinstance(launcher, _SubprocessScriptLauncher):
launcher.kill(_get_sigkill_signal())
exit(1)
except BaseException as exception:
_interrupt(trainer, exception)
trainer._teardown()
# teardown might access the stage so we reset it after
trainer.state.stage = None
raise
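
# Illustrative sketch only (comments, not executed): the public entry points dispatch through
# this wrapper, roughly like the call below; `MyModel` and `train_loader` are hypothetical.
#
#     trainer = pl.Trainer(max_epochs=1)
#     _call_and_handle_interrupt(trainer, trainer._fit_impl, MyModel(), train_loader)
#
# If the strategy has a launcher (e.g. subprocess-based DDP), the call is delegated to it;
# otherwise `trainer_fn` runs in the current process and any interrupt or error is funneled
# through the handlers above.

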
def _interrupt(trainer: "pl.Trainer", exception: BaseException) -> None:
trainer.state.status = TrainerStatus.INTERRUPTED
_call_callback_hooks(trainer, "on_exception", exception)
if trainer.datamodule is not None:
_call_lightning_datamodule_hook(trainer, "on_exception", exception)
trainer.strategy.on_exception(exception)
for logger in trainer.loggers:
logger.finalize("failed")


def _call_setup_hook(trainer: "pl.Trainer") -> None:
assert trainer.state.fn is not None
fn = trainer.state.fn
# It is too early to move the model to the device, but we fake the `LightningModule.device` property
# so the user can access it in the `LightningModule.setup` hook
for module in trainer.lightning_module.modules():
if isinstance(module, _DeviceDtypeModuleMixin):
module._device = trainer.strategy.root_device
# wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb:
# https://github.com/wandb/wandb/issues/1782#issuecomment-779161203
loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger))
# Trigger lazy creation of experiment in loggers so loggers have their metadata available
for logger in loggers:
if hasattr(logger, "experiment"):
_ = logger.experiment
trainer.strategy.barrier("pre_setup")
if trainer.datamodule is not None:
_call_lightning_datamodule_hook(trainer, "setup", stage=fn)
_call_callback_hooks(trainer, "setup", stage=fn)
_call_lightning_module_hook(trainer, "setup", stage=fn)
trainer.strategy.barrier("post_setup")
def _call_configure_model(trainer: "pl.Trainer") -> None:
# legacy hook
if is_overridden("configure_sharded_model", trainer.lightning_module):
with trainer.strategy.model_sharded_context():
_call_lightning_module_hook(trainer, "configure_sharded_model")
# we don't normally check for this before calling the hook. it is done here to avoid instantiating the context
# managers
if is_overridden("configure_model", trainer.lightning_module):
with (
trainer.strategy.tensor_init_context(),
trainer.strategy.model_sharded_context(),
trainer.precision_plugin.module_init_context(),
):
_call_lightning_module_hook(trainer, "configure_model")
def _call_teardown_hook(trainer: "pl.Trainer") -> None:
assert trainer.state.fn is not None
fn = trainer.state.fn
if trainer.datamodule is not None:
_call_lightning_datamodule_hook(trainer, "teardown", stage=fn)
_call_callback_hooks(trainer, "teardown", stage=fn)
_call_lightning_module_hook(trainer, "teardown", stage=fn)
trainer.lightning_module._current_fx_name = None
# these could have become stale if metrics are defined in `setup`
trainer.lightning_module._metric_attributes = None
    # TODO: TPU with 8 cores hangs in flush with TensorBoard; this might happen for all loggers.
    #  It may be related to XLA tensors being blocked when moved to the CPU while the loggers are killed.
for logger in trainer.loggers:
logger.finalize("success")
# summarize profile results
trainer.profiler.describe()


def _call_lightning_module_hook(
trainer: "pl.Trainer",
hook_name: str,
*args: Any,
pl_module: Optional["pl.LightningModule"] = None,
**kwargs: Any,
) -> Any:
log.debug(f"{trainer.__class__.__name__}: calling lightning module hook: {hook_name}")
pl_module = pl_module or trainer.lightning_module
if pl_module is None:
raise TypeError("No `LightningModule` is available to call hooks on.")
fn = getattr(pl_module, hook_name)
if not callable(fn):
return None
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
return output
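
# Illustrative call sites (not executed): the trainer internals use this helper to run user
# hooks with profiling and `_current_fx_name` bookkeeping, and the hook's return value is
# passed back to the caller, e.g.
#
#     _call_lightning_module_hook(trainer, "on_train_epoch_start")
#     optim_conf = _call_lightning_module_hook(trainer, "configure_optimizers")

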
def _call_lightning_datamodule_hook(
trainer: "pl.Trainer",
hook_name: str,
*args: Any,
**kwargs: Any,
) -> Any:
log.debug(f"{trainer.__class__.__name__}: calling lightning datamodule hook: {hook_name}")
if trainer.datamodule is None:
raise TypeError("No `LightningDataModule` is available to call hooks on.")
fn = getattr(trainer.datamodule, hook_name)
if callable(fn):
with trainer.profiler.profile(f"[LightningDataModule]{trainer.datamodule.__class__.__name__}.{hook_name}"):
return fn(*args, **kwargs)
return None
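
# Illustrative calls (not executed): the same pattern is used for datamodule hooks, e.g.
#
#     _call_lightning_datamodule_hook(trainer, "setup", stage="fit")
#     train_dl = _call_lightning_datamodule_hook(trainer, "train_dataloader")

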
def _call_callback_hooks(
trainer: "pl.Trainer",
hook_name: str,
*args: Any,
monitoring_callbacks: Optional[bool] = None,
**kwargs: Any,
) -> None:
log.debug(f"{trainer.__class__.__name__}: calling callback hook: {hook_name}")
pl_module = trainer.lightning_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
callbacks = trainer.callbacks
if monitoring_callbacks is True:
# the list of "monitoring callbacks" is hard-coded to these two. we could add an API to define this
callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
elif monitoring_callbacks is False:
callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]
for callback in callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
with trainer.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
fn(trainer, trainer.lightning_module, *args, **kwargs)
if pl_module:
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
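
# The `monitoring_callbacks` switch lets callers run the checkpoint/early-stopping callbacks
# separately from all other callbacks, e.g. so they only fire once the metrics they monitor
# have been logged (illustrative calls, not executed; hook name chosen for the example):
#
#     _call_callback_hooks(trainer, "on_validation_end", monitoring_callbacks=False)
#     _call_callback_hooks(trainer, "on_validation_end", monitoring_callbacks=True)

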
def _call_callbacks_state_dict(trainer: "pl.Trainer") -> dict[str, dict]:
"""Called when saving a model checkpoint, calls and returns every callback's `state_dict`, keyed by
`Callback.state_key`."""
callback_state_dicts = {}
for callback in trainer.callbacks:
state_dict = callback.state_dict()
if state_dict:
callback_state_dicts[callback.state_key] = state_dict
return callback_state_dicts
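
# The returned mapping ends up under the "callbacks" key of the checkpoint, keyed by each
# callback's `state_key`, which encodes the class name and identifying init arguments.
# Illustrative shape only; the real keys and values depend on the configured callbacks:
#
#     {
#         "EarlyStopping{'monitor': 'val_loss', ...}": {"wait_count": 0, "patience": 3, ...},
#         "ModelCheckpoint{'monitor': 'val_loss', ...}": {"best_model_score": ..., ...},
#     }

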
def _call_callbacks_on_save_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
"""Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook."""
pl_module = trainer.lightning_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = "on_save_checkpoint"
for callback in trainer.callbacks:
with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
callback.on_save_checkpoint(trainer, trainer.lightning_module, checkpoint)
if pl_module:
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name


def _call_callbacks_on_load_checkpoint(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
"""Called when loading a model checkpoint.
Calls every callback's `on_load_checkpoint` hook. We have a dedicated function for this rather than using
`_call_callback_hooks` because we have special logic for getting callback_states.
"""
pl_module = trainer.lightning_module
if pl_module:
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = "on_load_checkpoint"
callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")
if callback_states is None:
return
is_legacy_ckpt = Version(checkpoint["pytorch-lightning_version"]) < Version("1.5.0dev")
current_callbacks_keys = {cb._legacy_state_key if is_legacy_ckpt else cb.state_key for cb in trainer.callbacks}
difference = callback_states.keys() - current_callbacks_keys
if difference:
rank_zero_warn(
"Be aware that when using `ckpt_path`,"
" callbacks used to create the checkpoint need to be provided during `Trainer` instantiation."
f" Please add the following callbacks: {list(difference)}.",
)
for callback in trainer.callbacks:
with trainer.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
callback.on_load_checkpoint(trainer, trainer.lightning_module, checkpoint)
if pl_module:
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
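
# Illustrative scenario for the warning above (not executed): a checkpoint written with an
# `EarlyStopping` callback is resumed on a `Trainer` created without it, so its state key shows
# up in `difference` and the user is asked to re-add the callback. The path is hypothetical.
#
#     trainer = pl.Trainer()  # no EarlyStopping configured
#     trainer.fit(model, ckpt_path="last.ckpt")  # checkpoint contains EarlyStopping state

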
def _call_callbacks_load_state_dict(trainer: "pl.Trainer", checkpoint: dict[str, Any]) -> None:
"""Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
callback_states: Optional[dict[Union[type, str], dict]] = checkpoint.get("callbacks")
if callback_states is None:
return
for callback in trainer.callbacks:
state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
if state:
state = deepcopy(state)
callback.load_state_dict(state)


def _call_strategy_hook(
trainer: "pl.Trainer",
hook_name: str,
*args: Any,
**kwargs: Any,
) -> Any:
log.debug(f"{trainer.__class__.__name__}: calling strategy hook: {hook_name}")
pl_module = trainer.lightning_module
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
fn = getattr(trainer.strategy, hook_name)
if not callable(fn):
return None
with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
pl_module._current_fx_name = prev_fx_name
return output
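
# Illustrative call sites (not executed): the loops route the core step hooks through the
# strategy so that distributed wrappers and precision contexts are applied, e.g.
#
#     training_step_output = _call_strategy_hook(trainer, "training_step", *step_args)
#     _call_strategy_hook(trainer, "validation_step", *step_args)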