jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to data saving/loading."""
import errno
import io
import logging
from pathlib import Path
from typing import IO, Any, Union
import fsspec
import fsspec.utils
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import AbstractFileSystem
from lightning_utilities.core.imports import module_available
from lightning_fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
# Module-level logger named after this module (standard ``logging.getLogger(__name__)`` convention).
log = logging.getLogger(__name__)
def _load(
    path_or_url: Union[IO, _PATH],
    map_location: _MAP_LOCATION_TYPE = None,
    weights_only: bool = False,
) -> Any:
    """Loads a checkpoint.

    Args:
        path_or_url: Path or URL of the checkpoint, or an already-open binary file-like object (e.g. ``BytesIO``).
            ``http://`` / ``https://`` URLs are downloaded via ``torch.hub.load_state_dict_from_url``; any other
            path is opened through the matching fsspec filesystem.
        map_location: a function, ``torch.device``, string or a dict specifying how to remap storage locations.
        weights_only: Forwarded unchanged to the underlying torch loading function.

    Returns:
        The deserialized checkpoint object.

    """
    if not isinstance(path_or_url, (str, Path)):
        # any sort of BytesIO or similar
        return torch.load(
            path_or_url,
            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
            weights_only=weights_only,
        )
    path_str = str(path_or_url)
    # Only explicit http(s) URLs are treated as remote. A bare ``startswith("http")`` would also send local paths
    # such as ``httpserver_ckpt/last.ckpt`` to the URL downloader.
    if path_str.startswith(("http://", "https://")):
        return torch.hub.load_state_dict_from_url(
            path_str,
            map_location=map_location,  # type: ignore[arg-type]
            weights_only=weights_only,
        )
    fs = get_filesystem(path_or_url)
    with fs.open(path_or_url, "rb") as f:
        return torch.load(
            f,
            map_location=map_location,  # type: ignore[arg-type]
            weights_only=weights_only,
        )
def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
    """Resolve the fsspec filesystem responsible for ``path``.

    Args:
        path: Local path or URL; it is converted to ``str`` before resolution.
        **kwargs: Extra options forwarded unchanged to ``url_to_fs``.

    Returns:
        The filesystem instance. The stripped path that ``url_to_fs`` also returns is discarded.

    """
    resolved = url_to_fs(str(path), **kwargs)
    return resolved[0]
def _atomic_save(checkpoint: dict[str, Any], filepath: Union[str, Path]) -> None:
    """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

    The checkpoint is first serialized into an in-memory buffer. The bytes are then written inside an fsspec
    transaction, so an interrupted save does not leave a corrupted file behind.

    Args:
        checkpoint: The object to save.
            Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
            accepts.
        filepath: The path to which the checkpoint will be saved.
            This points to the file that the checkpoint will be stored in.

    Raises:
        RuntimeError: If writing failed with a ``PermissionError`` whose underlying cause is a cross-device
            (``EXDEV``) error, which requires a newer ``fsspec`` for local checkpoints.
        PermissionError: For any other permission failure while writing ``filepath``.

    """
    bytesbuffer = io.BytesIO()
    log.debug("Saving checkpoint: %s", filepath)
    torch.save(checkpoint, bytesbuffer)
    try:
        # We use a transaction here to avoid file corruption if the save gets interrupted
        fs, urlpath = fsspec.core.url_to_fs(str(filepath))
        with fs.transaction, fs.open(urlpath, "wb") as f:
            f.write(bytesbuffer.getvalue())
    except PermissionError as e:
        if isinstance(e.__context__, OSError) and getattr(e.__context__, "errno", None) == errno.EXDEV:
            raise RuntimeError(
                'Upgrade fsspec to enable cross-device local checkpoints: pip install "fsspec[http]>=2025.5.0"',
            ) from e
        # Any other permission error means the checkpoint was NOT written; never swallow it silently.
        raise
def _is_object_storage(fs: AbstractFileSystem) -> bool:
    """Tell whether ``fs`` is one of the recognized object-storage filesystems.

    The recognized backends are adlfs ``AzureBlobFileSystem``, gcsfs ``GCSFileSystem`` and s3fs ``S3FileSystem``.
    Each backend package is optional, so its class is imported only when the package is available. Probing stops
    at the first match.
    """
    if module_available("adlfs"):
        from adlfs import AzureBlobFileSystem as _AzureFS

        if isinstance(fs, _AzureFS):
            return True
    if module_available("gcsfs"):
        from gcsfs import GCSFileSystem as _GoogleFS

        if isinstance(fs, _GoogleFS):
            return True
    if module_available("s3fs"):
        from s3fs import S3FileSystem as _AmazonFS

        # Last candidate: the isinstance result is the final answer.
        return isinstance(fs, _AmazonFS)
    return False
def _is_dir(fs: AbstractFileSystem, path: Union[str, Path], strict: bool = False) -> bool:
    """Check if a path is directory-like.

    On regular filesystems this is simply ``fs.isdir(path)``. Object-storage backends have no real directories,
    so they get a more permissive rule unless ``strict`` is requested.

    Args:
        fs: The filesystem to check the path against.
        path: The path or URL to be checked.
        strict: Only relevant for Object Storage platforms. If ``False``, any path that is not an existing file
            counts as directory-like: a non-existing path (and its non-existing parents) will be created on the
            fly. If ``True``, ``fs.isdir`` decides. Defaults to False.

    Returns:
        Whether ``path`` should be treated as a directory.

    """
    if not _is_object_storage(fs):
        return fs.isdir(path)
    # Object storage fsspec's are inconsistent with other file systems because they do not have real directories,
    # see for instance https://gcsfs.readthedocs.io/en/latest/api.html?highlight=makedirs#gcsfs.core.GCSFileSystem.mkdir
    # In particular, `fs.makedirs` is a no-op, so in non-strict mode only a path already taken by a file is rejected.
    return fs.isdir(path) if strict else not fs.isfile(path)
def _is_local_file_protocol(path: _PATH) -> bool:
    """Return ``True`` if fsspec infers the local ``file`` protocol for ``path`` (given as-is or via ``str``)."""
    protocol = fsspec.utils.get_protocol(str(path))
    return protocol == "file"