|
import logging |
|
import sys |
|
from typing import Optional, Literal |
|
import os |
|
import shutil |
|
from zipfile import ZipFile |
|
from pathlib import Path |
|
from huggingface_hub import hf_hub_download, upload_file |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, |
|
format="%(asctime)s - %(levelname)s - %(message)s", |
|
handlers=[logging.StreamHandler(sys.stdout)], |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
__all__ = ["download_dataset_from_hf", "upload_file_to_hf", "download_file_from_hf"] |
|
|
|
|
|
def download_dataset_from_hf( |
|
repo_id: str, |
|
filename: str, |
|
dest_path: str, |
|
token: str = None, |
|
local_dir: str = "./downloads", |
|
remove_downloaded_file: bool = True, |
|
) -> None: |
|
""" |
|
Download a file from Hugging Face repository and unzip it to destination path |
|
|
|
Args: |
|
repo_id (str): Hugging Face repository ID (username/repo_name) |
|
filename (str): Name of the file to download from the repository |
|
dest_path (str): Destination path where contents will be unzipped |
|
token (str, optional): Hugging Face token, if None will prompt for login |
|
""" |
|
|
|
os.makedirs(dest_path, exist_ok=True) |
|
if token is None: |
|
logger.info("reading HF_TOKEN variable from environment") |
|
token = os.getenv("HF_TOKEN") |
|
|
|
|
|
downloaded_file = hf_hub_download( |
|
repo_id=repo_id, |
|
filename=filename, |
|
repo_type="dataset", |
|
local_dir=local_dir, |
|
token=token, |
|
) |
|
logger.info(f"Downloaded {filename} to {downloaded_file}") |
|
|
|
|
|
if filename.endswith(".zip"): |
|
|
|
with ZipFile(downloaded_file, "r") as zip_ref: |
|
zip_ref.extractall(dest_path) |
|
logger.info(f"Unzipped contents to {dest_path}") |
|
|
|
|
|
if remove_downloaded_file: |
|
os.remove(downloaded_file) |
|
logger.info(f"Cleaned up temporary file: {downloaded_file}") |
|
else: |
|
|
|
final_path = os.path.join(dest_path, filename) |
|
shutil.move(downloaded_file, final_path) |
|
logger.info(f"Moved {filename} to {final_path}") |
|
|
|
|
|
def download_file_from_hf( |
|
repo_id: str, |
|
repo_type: Literal["model", "dataset"], |
|
filename: str, |
|
dest_path: str, |
|
token: str = None, |
|
) -> None: |
|
""" |
|
Download a file from Hugging Face repository and unzip it to destination path |
|
|
|
Args: |
|
repo_id (str): Hugging Face repository ID (username/repo_name) |
|
repo_type: model for model repo, dataset for dataset repo |
|
filename (str): Name of the file to download from the repository |
|
dest_path (str): Destination path where contents will be unzipped |
|
token (str, optional): Hugging Face token, if None will prompt for login |
|
|
|
""" |
|
|
|
os.makedirs(dest_path, exist_ok=True) |
|
if token is None: |
|
logger.info("reading HF_TOKEN variable from environment") |
|
token = os.getenv("HF_TOKEN") |
|
|
|
|
|
downloaded_file = hf_hub_download( |
|
repo_id=repo_id, |
|
filename=filename, |
|
repo_type=repo_type, |
|
local_dir="./downloads", |
|
token=token, |
|
) |
|
logger.info(f"Downloaded {filename} to {downloaded_file}") |
|
|
|
|
|
if filename.endswith(".zip"): |
|
|
|
with ZipFile(downloaded_file, "r") as zip_ref: |
|
zip_ref.extractall(dest_path) |
|
logger.info(f"Unzipped contents to {dest_path}") |
|
|
|
|
|
os.remove(downloaded_file) |
|
logger.info(f"Cleaned up temporary file: {downloaded_file}") |
|
else: |
|
|
|
final_path = os.path.join(dest_path, filename) |
|
shutil.move(downloaded_file, final_path) |
|
logger.info(f"Moved {filename} to {final_path}") |
|
|
|
|
|
def upload_file_to_hf( |
|
local_file_path: str, |
|
repo_id: str, |
|
repo_type: Literal["model", "dataset"], |
|
token: Optional[str] = None, |
|
path_in_repo: Optional[str] = None, |
|
commit_message: str = "Upload file", |
|
) -> None: |
|
""" |
|
Upload a file to Hugging Face hub. |
|
|
|
Args: |
|
local_file_path (str): Path to the local .pt checkpoint file |
|
repo_id (str): Repository ID in format "username/repo_name" |
|
repo_type (str, optional): Type of repository, either "model" or "dataset" |
|
token (str): Hugging Face authentication token. Read from environment variable HF_TOKEN if don't provide |
|
path_in_repo (str, optional): Destination path in the repository. |
|
Defaults to the filename from local_checkpoint_path |
|
commit_message (str, optional): Commit message for the upload |
|
|
|
Raises: |
|
FileNotFoundError: If the checkpoint file doesn't exist |
|
ValueError: If the repository ID is invalid |
|
""" |
|
|
|
if not os.path.isfile(local_file_path): |
|
raise FileNotFoundError(f"File not found: {local_file_path}") |
|
|
|
|
|
if path_in_repo is None: |
|
path_in_repo = Path(local_file_path).name |
|
|
|
if token is None: |
|
logger.info("reading HF_TOKEN variable from environment") |
|
token = os.getenv("HF_TOKEN") |
|
if token is None: |
|
raise RuntimeError("not found HF_TOKEN variable from environment") |
|
|
|
upload_file( |
|
path_or_fileobj=local_file_path, |
|
path_in_repo=path_in_repo, |
|
repo_id=repo_id, |
|
repo_type=repo_type, |
|
token=token, |
|
commit_message=commit_message, |
|
) |
|
logger.info(f"Successfully uploaded {local_file_path} to {repo_id}/{path_in_repo}") |
|
|