File size: 1,692 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Root package info."""

import logging
import os
import sys

from lightning_utilities.core.imports import package_available

# Package metadata. ``__about__.py`` and ``__version__.py`` are optional, so check for them on disk before importing.
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__about__.py")):
    from lightning_fabric.__about__ import *  # noqa: F403
if os.path.isfile(os.path.join(os.path.dirname(__file__), "__version__.py")):
    from lightning_fabric.__version__ import version as __version__
elif package_available("lightning"):
    # No local ``__version__.py``: fall back to the version of the umbrella ``lightning`` package.
    # Importing ``__version__`` from ``lightning_fabric`` itself would be a circular import of a name that is not
    # bound yet (and there is no ``__version__`` submodule in this branch), so it always raised ``ImportError``.
    from lightning import __version__  # noqa: F401

# Package logger, set to emit INFO and above. When the root logger has no handlers (the application has not
# configured logging), attach a dedicated StreamHandler (writes to stderr) and turn off propagation so the
# package's messages are printed by that handler alone.
_root_logger, _logger = logging.getLogger(), logging.getLogger(__name__)
_logger.setLevel(logging.INFO)

if not _root_logger.hasHandlers():
    _logger.propagate = False
    _logger.addHandler(logging.StreamHandler())


# Make `torch.cuda.is_available()` and `torch.cuda.device_count()` use the NVML-based implementation,
# which doesn't poison forks. The value is forced to "1", overriding any value already in the environment.
# https://github.com/pytorch/pytorch/issues/83973
os.environ.update(PYTORCH_NVML_BASED_CUDA_CHECK="1")

# Windows only: force USE_LIBUV to "0".
# https://github.com/pytorch/pytorch/issues/139990
if sys.platform == "win32":
    os.environ.update(USE_LIBUV="0")


# Public API imports. They sit after the module-level environment setup (hence ``noqa: E402``), presumably so that
# the environment variables are already set when these modules (and torch) get imported. NOTE(review): confirm.
from lightning_fabric.fabric import Fabric  # noqa: E402
from lightning_fabric.utilities.seed import seed_everything  # noqa: E402
from lightning_fabric.utilities.warnings import disable_possible_user_warnings  # noqa: E402
from lightning_fabric.wrappers import is_wrapped  # noqa: E402

# this import needs to go last as it will patch other modules
import lightning_fabric._graveyard  # noqa: E402, F401  # isort: skip

# Names exported by ``from lightning_fabric import *``. ``disable_possible_user_warnings`` is imported for the
# POSSIBLE_USER_WARNINGS opt-out at the end of this module and is intentionally not listed here.
__all__ = ["Fabric", "seed_everything", "is_wrapped"]


# Opt-out switch: exporting POSSIBLE_USER_WARNINGS=0 or POSSIBLE_USER_WARNINGS=off (case-insensitive) makes the
# package call `disable_possible_user_warnings()` at import time. Unset or any other value leaves warnings alone.
if os.getenv("POSSIBLE_USER_WARNINGS", "").lower() in {"0", "off"}:
    disable_possible_user_warnings()