# Hugging Face file-viewer residue (not part of the module): last commit
# "add code" (37a9836) by sleeper371; raw file size 5.84 kB.
import logging
import sys
from typing import Optional, Literal
import os
import shutil
from zipfile import ZipFile
from pathlib import Path
from huggingface_hub import hf_hub_download, upload_file
# Set up logging at import time: INFO level, "timestamp - level - message"
# records, written to stdout through an explicit StreamHandler.
# NOTE: logging.basicConfig leaves the root logger untouched when it already
# has handlers, so an application that configured logging first keeps its setup.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)],
)
# Module-level logger used by every helper function in this file.
logger = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = ["download_dataset_from_hf", "upload_file_to_hf", "download_file_from_hf"]
def download_dataset_from_hf(
    repo_id: str,
    filename: str,
    dest_path: str,
    token: Optional[str] = None,
    local_dir: str = "./downloads",
    remove_downloaded_file: bool = True,
) -> None:
    """
    Download a file from a Hugging Face dataset repository into ``dest_path``.

    A file whose name ends with ``.zip`` is extracted into ``dest_path``; any
    other file is moved to ``dest_path/filename``.

    Args:
        repo_id (str): Hugging Face repository ID (username/repo_name)
        filename (str): Name of the file to download from the repository
        dest_path (str): Destination directory; created if it does not exist
        token (str, optional): Hugging Face token. When None, the HF_TOKEN
            environment variable is read; if that is unset as well, None is
            passed on to ``hf_hub_download``
        local_dir (str): Directory handed to ``hf_hub_download`` as
            ``local_dir``, used as the temporary download location
        remove_downloaded_file (bool): Only used for ``.zip`` files: delete
            the downloaded archive after extraction. Non-zip files are always
            moved out of the download location

    Errors raised by ``hf_hub_download``, ``ZipFile`` and the filesystem calls
    propagate to the caller unchanged.
    """
    # Ensure destination directory exists
    os.makedirs(dest_path, exist_ok=True)

    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.getenv("HF_TOKEN")

    # Download the file
    downloaded_file = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",  # Specify dataset repository
        local_dir=local_dir,  # Temporary download location
        token=token,
    )
    logger.info("Downloaded %s to %s", filename, downloaded_file)

    if filename.endswith(".zip"):
        # Extract the zip file
        with ZipFile(downloaded_file, "r") as zip_ref:
            zip_ref.extractall(dest_path)
        logger.info("Unzipped contents to %s", dest_path)

        # Clean up the downloaded zip file
        if remove_downloaded_file:
            os.remove(downloaded_file)
            logger.info("Cleaned up temporary file: %s", downloaded_file)
        return

    # If not a zip, just move the file
    final_path = os.path.join(dest_path, filename)
    # `filename` may contain sub-directories (e.g. "data/train.csv"); the move
    # fails unless the matching directories exist under dest_path.
    parent_dir = os.path.dirname(final_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.move(downloaded_file, final_path)
    logger.info("Moved %s to %s", filename, final_path)
def download_file_from_hf(
    repo_id: str,
    repo_type: Literal["model", "dataset"],
    filename: str,
    dest_path: str,
    token: Optional[str] = None,
    local_dir: str = "./downloads",
    remove_downloaded_file: bool = True,
) -> None:
    """
    Download a file from a Hugging Face model or dataset repository into
    ``dest_path``.

    A file whose name ends with ``.zip`` is extracted into ``dest_path``; any
    other file is moved to ``dest_path/filename``.

    Args:
        repo_id (str): Hugging Face repository ID (username/repo_name)
        repo_type: model for model repo, dataset for dataset repo
        filename (str): Name of the file to download from the repository
        dest_path (str): Destination directory; created if it does not exist
        token (str, optional): Hugging Face token. When None, the HF_TOKEN
            environment variable is read; if that is unset as well, None is
            passed on to ``hf_hub_download``
        local_dir (str): Directory handed to ``hf_hub_download`` as
            ``local_dir``, used as the temporary download location. Defaults
            to "./downloads", the location this function always used before
        remove_downloaded_file (bool): Only used for ``.zip`` files: delete
            the downloaded archive after extraction (the default, matching
            the previous behaviour). Non-zip files are always moved

    Errors raised by ``hf_hub_download``, ``ZipFile`` and the filesystem calls
    propagate to the caller unchanged.
    """
    # Ensure destination directory exists
    os.makedirs(dest_path, exist_ok=True)

    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.getenv("HF_TOKEN")

    # Download the file
    downloaded_file = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        local_dir=local_dir,  # Temporary download location
        token=token,
    )
    logger.info("Downloaded %s to %s", filename, downloaded_file)

    if filename.endswith(".zip"):
        # Extract the zip file
        with ZipFile(downloaded_file, "r") as zip_ref:
            zip_ref.extractall(dest_path)
        logger.info("Unzipped contents to %s", dest_path)

        # Clean up the downloaded zip file
        if remove_downloaded_file:
            os.remove(downloaded_file)
            logger.info("Cleaned up temporary file: %s", downloaded_file)
        return

    # If not a zip, just move the file
    final_path = os.path.join(dest_path, filename)
    # `filename` may contain sub-directories (e.g. "weights/model.pt"); the
    # move fails unless the matching directories exist under dest_path.
    parent_dir = os.path.dirname(final_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.move(downloaded_file, final_path)
    logger.info("Moved %s to %s", filename, final_path)
def upload_file_to_hf(
    local_file_path: str,
    repo_id: str,
    repo_type: Literal["model", "dataset"],
    token: Optional[str] = None,
    path_in_repo: Optional[str] = None,
    commit_message: str = "Upload file",
) -> None:
    """
    Push a single local file to a repository on the Hugging Face hub.

    Args:
        local_file_path (str): Path of the local file to upload
        repo_id (str): Repository ID in format "username/repo_name"
        repo_type (str): Kind of target repository, "model" or "dataset"
        token (str, optional): Hugging Face authentication token. When None,
            the HF_TOKEN environment variable is used instead
        path_in_repo (str, optional): Destination path in the repository.
            Defaults to the file name of local_file_path
        commit_message (str, optional): Commit message for the upload

    Raises:
        FileNotFoundError: If local_file_path is not an existing file
        RuntimeError: If no token is given and HF_TOKEN is not set
    """
    # Fail early, before touching credentials or the network
    if not os.path.isfile(local_file_path):
        raise FileNotFoundError(f"File not found: {local_file_path}")

    # Fall back to the bare file name when no repository path was requested
    path_in_repo = Path(local_file_path).name if path_in_repo is None else path_in_repo

    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.environ.get("HF_TOKEN")
        if token is None:
            raise RuntimeError("not found HF_TOKEN variable from environment")

    upload_file(
        repo_id=repo_id,
        repo_type=repo_type,
        path_or_fileobj=local_file_path,
        path_in_repo=path_in_repo,
        commit_message=commit_message,
        token=token,
    )
    logger.info(f"Successfully uploaded {local_file_path} to {repo_id}/{path_in_repo}")