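"""Training entry point for HuBERTForBarkSemantic, a HuBERT-based quantizer that
maps audio to Bark semantic tokens.

The script downloads the training dataset (and, optionally, a noise dataset for
augmentation) from the Hugging Face Hub, builds the model, and delegates the
training loop to core.trainer.train_hubert_quantizer.

Illustrative invocation (the script file name and flag values below are examples only):

    python train.py --num-epochs 3 --batch-size 4 --mix-precision --augment-data-with-noise
"""
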
import argparse
import logging
import os
from dataclasses import asdict
import torch
import torch.nn as nn
from core.trainer import train_hubert_quantizer
from core.model.hubert import (
    HuBERTForBarkSemantic,
    HubertForBarkSemanticConfig,
)
from core.utils import download_dataset_from_hf
from core.bark.constants import HUBERT_OUTPUT_VOCAB_SIZE


# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
WORKSPACE = "./"

# HF repo id to the dataset
DATASET_REPO_ID = "sleeper371/bark-wave-semantic"
# if publishing checkpoints to HF, this is the repo id the checkpoints are pushed to
CHECKPOINT_REPO_ID = "sleeper371/hubert-for-bark-semantic"
# name of the noise data file on the HF dataset repo
HF_NOISE_FILE_NAME = "environmental_sound.zip"


# local path to the noise data used to augment the training data
_LOCAL_NOISE_DATA_PATH = "noise_dataset"
# local path to the training audio folder
_LOCAL_TRAINING_DATA_PATH = "wav_semantic_dataset"
# local folder path to save trained checkpoint
_LOCAL_CHECKPOINTS_PATH = "checkpoints"


def prefix_workspace(workspace_path: str, path: str) -> str:
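    """Join a relative data path onto the workspace root directory."""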
    return os.path.join(workspace_path, path)


def parse_args():
    parser = argparse.ArgumentParser(description="HuBERT Training Script")
    parser.add_argument(
        "--hubert-checkpoint-name",
        type=str,
        default="facebook/hubert-base-ls960",
        help="checkpoint name that will be used as the feature extractor layer for CustomHuBERT",
    )
    parser.add_argument(
        "--feature-layer",
        type=int,
        default=11,
        help="layer at which to use features for the LSTM",
    )

    parser.add_argument(
        "--mix-precision",
        action="store_true",
        help="train model with mix precision bfloat16 and gradient scaler",
    )

    parser.add_argument(
        "--lr", type=float, default=8e-5, help="Learning rate (default: 8e-5)"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs (default: 3)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Train/validation split ratio (default: 0.8)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size for training (default: 16)",
    )
    parser.add_argument(
        "--dataset-file-name",
        type=str,
        default="short_sentences.zip",
        help="name of the dataset file in the HF repo to download",
    )

    parser.add_argument(
        "--save-checkpoint-every",
        type=int,
        default=1,
        help="Save checkpoint every N epochs (default: 1)",
    )

    parser.add_argument(
        "--model-bfloat16",
        action="store_true",
        default=False,
        help="set true to convert and train model in bfloat16",
    )

    parser.add_argument(
        "--augment-data-with-noise",
        action="store_true",
        default=False,
        help="load and add noise randomly to training data as a regularization technique",
    )

    parser.add_argument(
        "--augment-prob",
        type=float,
        default=0.5,
        help="noise will be added to audio sample with this probability",
    )

    parser.add_argument(
        "--publish-hf",
        action="store_true",
        default=False,
        help="if set, publish checkpoints to huggingface hub",
    )

    parser.add_argument(
        "--workspace",
        type=str,
        default=WORKSPACE,
        help="workspace folder to store data",
    )

    parser.add_argument(
        "--num_samples",
        type=int,
        default=10000,
        help="number of examples to load from the dataset",
    )

    return parser.parse_args()


def ensure_directory(path: str):
    """Create directory if it doesn't exist."""
    os.makedirs(path, exist_ok=True)


def calculate_model_memory(model: nn.Module):
    """
    Calculate and log the memory usage of a PyTorch model's parameters based on their detected data type.

    Args:
        model (nn.Module): The PyTorch model to analyze.
    """
    # Dictionary mapping PyTorch dtypes to bytes per parameter
    bytes_per_param_dict = {
        torch.float32: 4,  # 32 bits = 4 bytes
        torch.bfloat16: 2,  # 16 bits = 2 bytes (relevant when --model-bfloat16 is set)
        torch.float16: 2,  # 16 bits = 2 bytes
        torch.int8: 1,  # 8 bits = 1 byte
        torch.int32: 4,  # 32 bits = 4 bytes
        torch.int64: 8,  # 64 bits = 8 bytes
    }

    # Detect the data type from the first parameter
    param_iter = iter(model.parameters())
    try:
        first_param = next(param_iter)
        dtype = first_param.dtype
    except StopIteration:
        print("Model has no parameters!")
        return

    # Get bytes per parameter based on detected dtype
    # Default to 4 bytes if dtype not found
    bytes_per_param = bytes_per_param_dict.get(dtype, 4)
    dtype_name = str(dtype).replace("torch.", "")  # Clean up dtype name for printing

    # Count total number of parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Calculate total memory in bytes
    total_memory_bytes = total_params * bytes_per_param

    # Convert to KB, MB, and GB for readability
    total_memory_kb = total_memory_bytes / 1024
    total_memory_mb = total_memory_kb / 1024
    total_memory_gb = total_memory_mb / 1024

    # Print results
    logger.info(f"Model Memory Usage (Detected dtype: {dtype_name}):")
    logger.info(f"Total Parameters: {total_params:,}")
    logger.info(f"Total Memory: {total_memory_gb:,.2f} GB")


def main():
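    """Parse CLI arguments, prepare data and model, and run training.

    Rough flow: resolve workspace paths -> pick a device (CUDA, then MPS, then
    CPU) -> build HuBERTForBarkSemantic from its config -> optionally cast the
    model to bfloat16 -> download datasets from the HF Hub if missing -> call
    train_hubert_quantizer.
    """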
    args = parse_args()

    # local path to the noise data used to augment the training data
    LOCAL_NOISE_DATA_PATH = prefix_workspace(args.workspace, _LOCAL_NOISE_DATA_PATH)
    # local path to the training audio folder
    LOCAL_TRAINING_DATA_PATH = prefix_workspace(
        args.workspace, _LOCAL_TRAINING_DATA_PATH
    )
    # local folder path to save trained checkpoint
    LOCAL_CHECKPOINTS_PATH = prefix_workspace(args.workspace, _LOCAL_CHECKPOINTS_PATH)

    # Create necessary directories
    ensure_directory(LOCAL_CHECKPOINTS_PATH)

    logger.info("Starting HuBERT training")

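    # Prefer CUDA, fall back to Apple Silicon MPS, then CPU.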
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else (
            torch.device("mps")
            if torch.backends.mps.is_available()
            else torch.device("cpu")
        )
    )

    config = HubertForBarkSemanticConfig(
        vocab_size=HUBERT_OUTPUT_VOCAB_SIZE,
        checkpoint_name=args.hubert_checkpoint_name,
        feature_layer=args.feature_layer,
        num_decoder_layer=6,
    )
    model = HuBERTForBarkSemantic(
        config=config, load_hubert_pretrained_weights=True, device=device
    )

    if args.model_bfloat16:
        model = model.to(torch.bfloat16)
        logger.info("Training model in bfloat16 precision")

    calculate_model_memory(model)

    # Download datasets if needed
    if not os.path.exists(LOCAL_TRAINING_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            args.dataset_file_name,
            LOCAL_TRAINING_DATA_PATH,
        )

    if args.augment_data_with_noise and not os.path.exists(LOCAL_NOISE_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            HF_NOISE_FILE_NAME,
            LOCAL_NOISE_DATA_PATH,
        )

    # Train the model
    trained_model = train_hubert_quantizer(
        model=model,
        model_config=asdict(config),
        lr=args.lr,
        num_epoch=args.num_epochs,
        train_ratio=args.train_ratio,
        batch_size=args.batch_size,
        data_path=LOCAL_TRAINING_DATA_PATH,
        checkpoint_path=LOCAL_CHECKPOINTS_PATH,
        save_checkpoint_every=args.save_checkpoint_every,
        augment_data_with_noise=args.augment_data_with_noise,
        augment_prob=args.augment_prob,
        noise_data_path=LOCAL_NOISE_DATA_PATH,
        publish_hf=args.publish_hf,
        publish_to_repo=CHECKPOINT_REPO_ID,
        device=device,
        num_samples=args.num_samples,
        enable_grad_scaler=args.mix_precision,
    )
    logger.info("Training completed")

    return trained_model


if __name__ == "__main__":
    main()