File size: 3,614 Bytes
a0e2cb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os, sys
import torch, torchaudio
import argparse
import json
from omegaconf import MISSING, OmegaConf,DictConfig
from huggingface_hub import hf_hub_download

# NOTE(review): this flag is set before the SongBloom import below, presumably
# because the package reads it at import time to decide whether to use flash
# attention — confirm against the SongBloom package before reordering.
os.environ['DISABLE_FLASH_ATTN'] = "1"
from SongBloom.models.songbloom.songbloom_pl import SongBloom_Sampler


def hf_download(repo_id="CypressYang/SongBloom", model_name="songbloom_full_150s", local_dir="./cache", **kwargs):
    """Request every artifact needed for inference from the Hugging Face Hub.

    Five files are requested from ``repo_id``, always in the same order: the
    model config (``{model_name}.yaml``), the model checkpoint
    (``{model_name}.pt``), the VAE config, the VAE checkpoint and the G2P
    vocabulary. Each request goes through ``hf_hub_download`` with
    ``local_dir`` passed as its ``local_dir`` argument.

    Args:
        repo_id: Hub repository to download from.
        model_name: Base name shared by the model's ``.yaml`` / ``.pt`` pair.
        local_dir: Forwarded to ``hf_hub_download`` as ``local_dir``.
        **kwargs: Forwarded unchanged to every ``hf_hub_download`` call.

    Returns:
        None. The paths returned by ``hf_hub_download`` are not kept.
    """
    required_files = (
        f"{model_name}.yaml",              # model config
        f"{model_name}.pt",                # model checkpoint
        "stable_audio_1920_vae.json",      # VAE config
        "autoencoder_music_dsp1920.ckpt",  # VAE checkpoint
        "vocab_g2p.yaml",                  # G2P vocabulary
    )
    for remote_name in required_files:
        hf_hub_download(
            repo_id=repo_id, filename=remote_name, local_dir=local_dir, **kwargs)
    

    
    


def load_config(cfg_file, parent_dir="./") -> DictConfig:
    """Register the custom OmegaConf resolvers and load a config file.

    Resolvers registered (by name):
        eval:         evaluates its argument with Python ``eval``.
        concat:       flattens its arguments (each an iterable) into one list.
        get_fname:    base file name of a path, without the extension.
        load_yaml:    loads another file through ``OmegaConf.load``.
        dynamic_path: replaces every ``"???"`` in its argument with
                      ``parent_dir``.

    The resolvers are registered with ``replace=True`` so this function can be
    called more than once in the same process (and so ``dynamic_path`` always
    reflects the ``parent_dir`` of the most recent call) instead of raising
    because the names are already registered.

    Args:
        cfg_file: Path of the YAML config to load, or ``None``.
        parent_dir: Replacement text for ``"???"`` in ``dynamic_path``.

    Returns:
        The config loaded from ``cfg_file``, or an empty config created with
        ``OmegaConf.create()`` when ``cfg_file`` is ``None``.
    """
    resolvers = {
        # SECURITY: `eval` executes arbitrary Python taken from the config
        # text. Only load config files from sources you trust.
        "eval": lambda x: eval(x),
        "concat": lambda *x: [xxx for xx in x for xxx in xx],
        "get_fname": lambda x: os.path.splitext(os.path.basename(x))[0],
        "load_yaml": lambda x: OmegaConf.load(x),
        "dynamic_path": lambda x: x.replace("???", parent_dir),
    }
    for resolver_name, resolver_fn in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, resolver_fn, replace=True)

    if cfg_file is None:
        return OmegaConf.create()

    # Close the handle deterministically; the file is parsed inside `load`.
    with open(cfg_file, 'r') as cfg_stream:
        return OmegaConf.load(cfg_stream)



def main():
    """Command-line entry point: download the model, then generate audio.

    Steps:
        1. Parse CLI arguments.
        2. Download the model files via ``hf_download`` and load the model
           config via ``load_config``.
        3. Build the sampler with ``SongBloom_Sampler.build_from_trainer`` and
           apply the ``inference`` section of the config as generation params.
        4. Read the input JSONL file; every non-blank line must be a JSON
           object with the keys ``idx``, ``lyrics`` and ``prompt_wav``.
        5. For each entry, load the prompt audio, resample it to the model's
           sample rate if needed, average the channels, keep the first
           ``10 * model.sample_rate`` samples, and write ``--n-samples``
           generations to ``{output_dir}/{idx}_s{i}.flac``.

    Raises:
        json.JSONDecodeError: if a non-blank input line is not valid JSON.
        KeyError: if an input entry lacks one of the required keys.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", type=str, default="CypressYang/SongBloom")
    parser.add_argument("--model-name", type=str, default="songbloom_full_150s")
    parser.add_argument("--local-dir", type=str, default="./cache")
    parser.add_argument("--input-jsonl", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default="./output")
    parser.add_argument("--n-samples", type=int, default=2)
    parser.add_argument("--dtype", type=str, default='float32', choices=['float32', 'bfloat16'])

    args = parser.parse_args()

    hf_download(args.repo_id, args.model_name, args.local_dir)
    cfg = load_config(f"{args.local_dir}/{args.model_name}.yaml", parent_dir=args.local_dir)

    dtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
    model = SongBloom_Sampler.build_from_trainer(cfg, strict=True, dtype=dtype)
    model.set_generation_params(**cfg.inference)

    os.makedirs(args.output_dir, exist_ok=True)

    # Read the JSONL with a context manager so the handle is closed, and skip
    # blank lines (e.g. a trailing empty line) that json.loads would reject.
    with open(args.input_jsonl, 'r') as jsonl_file:
        input_lines = [
            json.loads(raw_line.strip())
            for raw_line in jsonl_file
            if raw_line.strip()
        ]

    for test_sample in input_lines:
        idx, lyrics, prompt_wav = test_sample["idx"], test_sample["lyrics"], test_sample["prompt_wav"]

        # `prompt_wav` is a path up to here; from now on it is the waveform.
        prompt_wav, sr = torchaudio.load(prompt_wav)
        if sr != model.sample_rate:
            prompt_wav = torchaudio.functional.resample(prompt_wav, sr, model.sample_rate)
        # Average over dim 0 (channels, per torchaudio.load) and cast to the
        # dtype the model was built with.
        prompt_wav = prompt_wav.mean(dim=0, keepdim=True).to(dtype)
        # Keep at most the first 10 * sample_rate samples of the prompt.
        prompt_wav = prompt_wav[..., :10*model.sample_rate]

        for i in range(args.n_samples):
            wav = model.generate(lyrics, prompt_wav)
            torchaudio.save(f'{args.output_dir}/{idx}_s{i}.flac', wav[0].cpu().float(), model.sample_rate)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()