|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Callable, Optional |
|
|
|
from typing_extensions import override |
|
|
|
from lightning_fabric.utilities.exceptions import MisconfigurationException |
|
from lightning_fabric.utilities.registry import _register_classes |
|
|
|
|
|
class _AcceleratorRegistry(dict):
    """Registry that maps string names to accelerator classes and their registration metadata.

    Each name identifies an accelerator, e.g. ``"gpu"``. Alongside the accelerator itself, the registry
    stores an optional description and the keyword parameters that were given at registration time and
    that are used to instantiate the accelerator.

    The motivation for having an AcceleratorRegistry is to make it convenient for users to try different
    accelerators by passing the registered alias to the ``accelerator`` flag of the Trainer.

    Each entry is a dict with the keys ``"description"``, ``"init_params"``, ``"accelerator"`` and
    ``"accelerator_name"``.

    Example::

        @AcceleratorRegistry.register("sota", description="Custom sota accelerator", a=1, b=True)
        class SOTAAccelerator(Accelerator):
            def __init__(self, a, b):
                ...

        or

        AcceleratorRegistry.register("sota", SOTAAccelerator, description="Custom sota accelerator", a=1, b=True)

    """

    def register(
        self,
        name: str,
        accelerator: Optional[Callable] = None,
        description: str = "",
        override: bool = False,
        **init_params: Any,
    ) -> Callable:
        """Register an accelerator under ``name`` together with its metadata.

        Can be called directly with the accelerator, or used as a class decorator by omitting
        ``accelerator``.

        Args:
            name: the name that identifies an accelerator, e.g. "gpu"
            accelerator: accelerator class (any callable accepting ``init_params``). If ``None``, a
                decorator is returned that registers the object it is applied to.
            description: accelerator description
            override: replaces an accelerator already registered under ``name``, if True
            init_params: keyword arguments passed to ``accelerator`` when :meth:`get` instantiates it

        Returns:
            ``accelerator`` itself when it was passed; otherwise a one-argument decorator that registers
            its argument and returns it unchanged.

        Raises:
            TypeError: if ``name`` is not a string.
            MisconfigurationException: if ``name`` is already registered and ``override`` is False.

        """
        if not isinstance(name, str):
            raise TypeError(f"`name` must be a str, found {name}")

        if name in self and not override:
            raise MisconfigurationException(f"'{name}' is already present in the registry. HINT: Use `override=True`.")

        def do_register(accelerator: Callable) -> Callable:
            # Takes a single argument so it works as a class decorator; ``name`` comes from the closure.
            # The entry is built here, once the accelerator is known, rather than pre-built and mutated.
            self[name] = {
                "description": description,
                "init_params": init_params,
                "accelerator": accelerator,
                "accelerator_name": name,
            }
            return accelerator

        if accelerator is not None:
            return do_register(accelerator)

        return do_register

    @override
    def get(self, name: str, default: Optional[Any] = None) -> Any:
        """Instantiate the accelerator registered under ``name`` with its stored ``init_params``.

        Unlike ``dict.get``, a missing name raises unless a non-``None`` ``default`` is supplied.

        Args:
            name: the name that identifies an accelerator, e.g. "gpu"
            default: value returned when ``name`` is not registered; ``None`` means "no fallback"

        Returns:
            A new accelerator object, or ``default`` when ``name`` is unknown and ``default`` is not ``None``.

        Raises:
            KeyError: if ``name`` is not registered and ``default`` is ``None``.

        """
        if name in self:
            data = self[name]
            return data["accelerator"](**data["init_params"])

        if default is not None:
            return default

        err_msg = "'{}' not found in registry. Available names: {}"
        available_names = self.available_accelerators()
        raise KeyError(err_msg.format(name, available_names))

    def remove(self, name: str) -> None:
        """Removes the registered accelerator by name; raises ``KeyError`` if it is not registered."""
        self.pop(name)

    def available_accelerators(self) -> set[str]:
        """Returns a set of registered accelerator names."""
        return set(self.keys())

    def __str__(self) -> str:
        # Sorted so the text is stable across interpreter runs: iteration order of a set of strings
        # depends on hash randomization.
        return "Registered Accelerators: {}".format(", ".join(sorted(self.available_accelerators())))
|
|
|
|
|
def call_register_accelerators(registry: _AcceleratorRegistry, base_module: str) -> None:
    """Legacy.

    Do not use.

    Imports ``base_module`` by its dotted path, then passes the module, the ``Accelerator`` base class
    and the hook name ``"register_accelerators"`` on to ``_register_classes`` together with ``registry``.

    """
    from importlib import import_module

    # Import the target module first, matching the original import order.
    target_module = import_module(base_module)

    # Imported lazily inside the function (presumably to avoid an import cycle at module load -- verify).
    from lightning_fabric.accelerators.accelerator import Accelerator

    hook_name = "register_accelerators"
    _register_classes(registry, hook_name, target_module, Accelerator)
|
|