jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing_extensions import override
from lightning_fabric.utilities.exceptions import MisconfigurationException
from lightning_fabric.utilities.registry import _register_classes
class _AcceleratorRegistry(dict):
    """This class is a Registry that stores information about the Accelerators.

    The Accelerators are mapped to strings. These strings are names that identify
    an accelerator, e.g., "gpu". It also stores an optional description and the
    parameters to initialize the Accelerator, which were defined during the
    registration.

    The motivation for having an AcceleratorRegistry is to make it convenient
    for users to try different accelerators by passing mapped aliases
    to the accelerator flag of the Trainer.

    Example::

        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
        class SOTAAccelerator(Accelerator):
            def __init__(self, a, b):
                ...

        or

        AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)

    """

    def register(
        self,
        name: str,
        accelerator: Optional[Callable] = None,
        description: str = "",
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """Registers an accelerator mapped to a name and with required metadata.

        Can be called directly (``register("name", cls, ...)``) or used as a class
        decorator (``@register("name", ...)``).

        Args:
            name: the name that identifies an accelerator, e.g. "gpu"
            accelerator: accelerator class. If ``None``, a decorator is returned that
                registers the class it is applied to.
            description: accelerator description
            override: overrides the registered accelerator, if True
            init_params: parameters to initialize the accelerator

        Returns:
            The ``accelerator`` itself when it is given; otherwise a decorator that
            registers and returns the decorated class.

        Raises:
            TypeError: if ``name`` is neither a str nor None.
            MisconfigurationException: if ``name`` is already registered and ``override`` is False.
        """
        if not (name is None or isinstance(name, str)):
            raise TypeError(f"`name` must be a str, found {name}")

        # The duplicate check runs when ``register`` is called, i.e. before the
        # decorator (if any) is applied.
        if name in self and not override:
            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

        data: dict[str, Any] = {}
        data["description"] = description
        data["init_params"] = init_params

        def do_register(accelerator: Callable) -> Callable:
            # Takes a single argument so it can serve as a decorator: Python calls a
            # decorator with the decorated class only. ``name`` comes from the closure.
            data["accelerator"] = accelerator
            data["accelerator_name"] = name
            self[name] = data
            return accelerator

        if accelerator is not None:
            return do_register(accelerator)

        return do_register

    @override
    def get(self, name: str, default: Optional[Any] = None) -> Any:
        """Calls the registered accelerator with the required parameters and returns the accelerator object.

        Unlike ``dict.get``, this instantiates the stored accelerator instead of
        returning the stored registry entry.

        Args:
            name: the name that identifies an accelerator, e.g. "gpu"
            default: value returned when ``name`` is not registered. A ``None``
                default means "no fallback": a ``KeyError`` is raised instead.

        Raises:
            KeyError: if ``name`` is not registered and ``default`` is None.
        """
        if name in self:
            data = self[name]
            return data["accelerator"](**data["init_params"])

        if default is not None:
            return default

        err_msg = "'{}' not found in registry. Available names: {}"
        available_names = self.available_accelerators()
        raise KeyError(err_msg.format(name, available_names))

    def remove(self, name: str) -> None:
        """Removes the registered accelerator by name.

        Raises:
            KeyError: if ``name`` is not registered.
        """
        self.pop(name)

    def available_accelerators(self) -> set[str]:
        """Returns a set of registered accelerators."""
        return set(self.keys())

    def __str__(self) -> str:
        # Sorted so the rendered list is stable across runs (set order is not).
        return "Registered Accelerators: {}".format(", ".join(sorted(self.available_accelerators())))
def call_register_accelerators(registry: _AcceleratorRegistry, base_module: str) -> None:  # pragma: no-cover
    """Legacy.

    Do not use.

    Imports the module named by ``base_module`` and hands it to ``_register_classes``
    together with the ``"register_accelerators"`` hook name and the ``Accelerator``
    base class.
    NOTE(review): presumably this calls ``register_accelerators(registry)`` on each
    ``Accelerator`` subclass found in the module; confirm in
    ``lightning_fabric.utilities.registry``.
    """
    from importlib import import_module

    # The target module is imported before ``Accelerator``, matching the original order.
    target_module = import_module(base_module)

    from lightning_fabric.accelerators.accelerator import Accelerator

    _register_classes(registry, "register_accelerators", target_module, Accelerator)