File size: 5,840 Bytes
37a9836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import logging
import sys
from typing import Optional, Literal
import os
import shutil
from zipfile import ZipFile
from pathlib import Path
from huggingface_hub import hf_hub_download, upload_file

# Public API of this module.
__all__ = [
    "download_dataset_from_hf",
    "upload_file_to_hf",
    "download_file_from_hf",
]

# Configure logging at import time: INFO level, timestamped records on stdout.
# NOTE(review): logging.basicConfig configures the *root* logger (and is a no-op
# if the root logger already has handlers), so importing this module affects
# logging for the whole process.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)


def download_dataset_from_hf(
    repo_id: str,
    filename: str,
    dest_path: str,
    token: Optional[str] = None,
    local_dir: str = "./downloads",
    remove_downloaded_file: bool = True,
) -> None:
    """
    Download a file from a Hugging Face *dataset* repository into ``dest_path``.

    If ``filename`` ends with ``.zip`` the archive is extracted into
    ``dest_path``; otherwise the downloaded file is moved to
    ``dest_path/filename``.

    Args:
        repo_id (str): Hugging Face repository ID (username/repo_name).
        filename (str): Path of the file inside the repository. May include
            sub-directories (e.g. ``"data/train.csv"``); they are recreated
            under ``dest_path`` for non-zip files.
        dest_path (str): Destination directory; created if it does not exist.
        token (str, optional): Hugging Face token. If None, the ``HF_TOKEN``
            environment variable is read; if that is unset too, ``None`` is
            passed through to ``hf_hub_download``.
        local_dir (str, optional): Staging directory that ``hf_hub_download``
            writes into. Defaults to ``"./downloads"``.
        remove_downloaded_file (bool, optional): Only relevant for ``.zip``
            files: delete the downloaded archive after extraction. Non-zip
            files are always moved (not copied) out of ``local_dir``.
    """
    # Ensure destination directory exists
    os.makedirs(dest_path, exist_ok=True)
    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.getenv("HF_TOKEN")

    downloaded_file = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",  # this helper only targets dataset repositories
        local_dir=local_dir,  # staging location for the raw download
        token=token,
    )
    logger.info("Downloaded %s to %s", filename, downloaded_file)

    if filename.endswith(".zip"):
        # NOTE: only extract archives from trusted repositories; extractall
        # writes every member of the archive under dest_path.
        with ZipFile(downloaded_file, "r") as zip_ref:
            zip_ref.extractall(dest_path)
        logger.info("Unzipped contents to %s", dest_path)

        if remove_downloaded_file:
            os.remove(downloaded_file)
            logger.info("Cleaned up temporary file: %s", downloaded_file)
    else:
        final_path = os.path.join(dest_path, filename)
        # filename may contain sub-directories ("data/train.csv"); create them
        # under dest_path first, otherwise shutil.move fails on a missing parent.
        os.makedirs(os.path.dirname(final_path), exist_ok=True)
        shutil.move(downloaded_file, final_path)
        logger.info("Moved %s to %s", filename, final_path)


def download_file_from_hf(
    repo_id: str,
    repo_type: Literal["model", "dataset"],
    filename: str,
    dest_path: str,
    token: Optional[str] = None,
    local_dir: str = "./downloads",
    remove_downloaded_file: bool = True,
) -> None:
    """
    Download a file from a Hugging Face model or dataset repository into ``dest_path``.

    If ``filename`` ends with ``.zip`` the archive is extracted into
    ``dest_path``; otherwise the downloaded file is moved to
    ``dest_path/filename``.

    Args:
        repo_id (str): Hugging Face repository ID (username/repo_name).
        repo_type (str): ``"model"`` for a model repo, ``"dataset"`` for a
            dataset repo.
        filename (str): Path of the file inside the repository. May include
            sub-directories; they are recreated under ``dest_path`` for
            non-zip files.
        dest_path (str): Destination directory; created if it does not exist.
        token (str, optional): Hugging Face token. If None, the ``HF_TOKEN``
            environment variable is read; if that is unset too, ``None`` is
            passed through to ``hf_hub_download``.
        local_dir (str, optional): Staging directory that ``hf_hub_download``
            writes into. Defaults to ``"./downloads"`` (the previously
            hard-coded value).
        remove_downloaded_file (bool, optional): Only relevant for ``.zip``
            files: delete the downloaded archive after extraction. Defaults to
            True (the previous unconditional behavior). Non-zip files are
            always moved (not copied) out of ``local_dir``.
    """
    # Ensure destination directory exists
    os.makedirs(dest_path, exist_ok=True)
    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.getenv("HF_TOKEN")

    downloaded_file = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
        local_dir=local_dir,  # staging location for the raw download
        token=token,
    )
    logger.info("Downloaded %s to %s", filename, downloaded_file)

    if filename.endswith(".zip"):
        # NOTE: only extract archives from trusted repositories; extractall
        # writes every member of the archive under dest_path.
        with ZipFile(downloaded_file, "r") as zip_ref:
            zip_ref.extractall(dest_path)
        logger.info("Unzipped contents to %s", dest_path)

        if remove_downloaded_file:
            os.remove(downloaded_file)
            logger.info("Cleaned up temporary file: %s", downloaded_file)
    else:
        final_path = os.path.join(dest_path, filename)
        # filename may contain sub-directories; create them under dest_path
        # first, otherwise shutil.move fails on a missing parent directory.
        os.makedirs(os.path.dirname(final_path), exist_ok=True)
        shutil.move(downloaded_file, final_path)
        logger.info("Moved %s to %s", filename, final_path)


def upload_file_to_hf(
    local_file_path: str,
    repo_id: str,
    repo_type: Literal["model", "dataset"],
    token: Optional[str] = None,
    path_in_repo: Optional[str] = None,
    commit_message: str = "Upload file",
) -> None:
    """
    Upload a single local file to a Hugging Face Hub repository.

    Args:
        local_file_path (str): Path to the local file to upload (any file type).
        repo_id (str): Repository ID in format "username/repo_name".
        repo_type (str): Type of repository, either "model" or "dataset".
        token (str, optional): Hugging Face authentication token. If None, it
            is read from the ``HF_TOKEN`` environment variable.
        path_in_repo (str, optional): Destination path in the repository.
            Defaults to the base name of ``local_file_path``.
        commit_message (str, optional): Commit message for the upload.

    Raises:
        FileNotFoundError: If ``local_file_path`` is not an existing regular file.
        RuntimeError: If no token is given and ``HF_TOKEN`` is unset or empty.

    Errors raised by ``huggingface_hub.upload_file`` itself are not caught
    here and propagate to the caller.
    """
    # Validate the file before doing anything else (directories are rejected too).
    if not os.path.isfile(local_file_path):
        raise FileNotFoundError(f"File not found: {local_file_path}")

    # Use the file's base name as the default path_in_repo.
    if path_in_repo is None:
        path_in_repo = Path(local_file_path).name

    if token is None:
        logger.info("reading HF_TOKEN variable from environment")
        token = os.getenv("HF_TOKEN")
        # An exported-but-empty HF_TOKEN is as unusable as a missing one;
        # reject it instead of passing an empty string on to upload_file.
        if not token:
            raise RuntimeError("not found HF_TOKEN variable from environment")

    upload_file(
        path_or_fileobj=local_file_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type=repo_type,
        token=token,
        commit_message=commit_message,
    )
    logger.info(
        "Successfully uploaded %s to %s/%s", local_file_path, repo_id, path_in_repo
    )