Gregniuki commited on
Commit
804231d
·
verified ·
1 Parent(s): dbcc874

Delete infer/infer_cli.py

Browse files
Files changed (1) hide show
  1. infer/infer_cli.py +0 -226
infer/infer_cli.py DELETED
@@ -1,226 +0,0 @@
1
- import argparse
2
- import codecs
3
- import os
4
- import re
5
- from importlib.resources import files
6
- from pathlib import Path
7
-
8
- import numpy as np
9
- import soundfile as sf
10
- import tomli
11
- from cached_path import cached_path
12
-
13
- from infer.utils_infer import (
14
- infer_process,
15
- load_model,
16
- load_vocoder,
17
- preprocess_ref_audio_text,
18
- remove_silence_for_generated_wav,
19
- )
20
- from model import DiT, UNetT
21
-
22
# ---------------------------------------------------------------------------
# Command-line interface.  Any option supplied here overrides the matching
# entry in the TOML configuration file (see --config).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    prog="python3 infer-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=infer/examples/basic/basic.toml",
    default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
)
parser.add_argument("-m", "--model", help="F5-TTS | E2-TTS")
parser.add_argument("-p", "--ckpt_file", help="The Checkpoint .pt")
parser.add_argument("-v", "--vocab_file", help="The vocab .txt")
parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
# "666" is a sentinel default used later to detect that -s was not given.
parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
parser.add_argument("-t", "--gen_text", type=str, help="Text to generate.")
parser.add_argument("-f", "--gen_file", type=str, help="File with text to generate. Ignores --gen_text")
parser.add_argument("-o", "--output_dir", type=str, help="Path to output folder..")
parser.add_argument("-w", "--output_file", type=str, help="Filename of output file..")
parser.add_argument("--remove_silence", help="Remove silence.")
parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
parser.add_argument(
    "--speed",
    type=float,
    default=1.0,
    help="Adjust the speed of the audio generation (default: 1.0)",
)
args = parser.parse_args()
91
-
92
# ---------------------------------------------------------------------------
# Resolve effective settings: CLI arguments take precedence over the TOML
# config file.  (--ref_text uses the sentinel "666" to mean "not supplied".)
# ---------------------------------------------------------------------------
# Context manager closes the config file promptly (the original
# tomli.load(open(...)) left the handle to the GC).
with open(args.config, "rb") as config_fp:
    config = tomli.load(config_fp)

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config["gen_file"]

# patches for pip pkg user: remap in-repo example paths to installed
# f5_tts package data.
if "infer/examples/" in ref_audio:
    ref_audio = str(files("f5_tts").joinpath(f"{ref_audio}"))
if "infer/examples/" in gen_file:
    gen_file = str(files("f5_tts").joinpath(f"{gen_file}"))
if "voices" in config:
    for voice in config["voices"]:
        voice_ref_audio = config["voices"][voice]["ref_audio"]
        if "infer/examples/" in voice_ref_audio:
            config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))

if gen_file:
    # --gen_file overrides --gen_text; read the whole file and close the
    # handle (the original codecs.open(...).read() leaked it).
    with open(gen_file, "r", encoding="utf-8") as gen_fp:
        gen_text = gen_fp.read()
output_dir = args.output_dir if args.output_dir else config["output_dir"]
output_file = args.output_file if args.output_file else config["output_file"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
speed = args.speed
119
-
120
# Target path for the generated waveform.
wave_path = Path(output_dir) / output_file
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"

# Both names refer to the same user-selected vocoder ("vocos" | "bigvgan").
vocoder_name = args.vocoder_name
mel_spec_type = args.vocoder_name

# Local checkpoint directories, only used with --load_vocoder_from_local.
# argparse `choices` guarantees vocoder_name is one of these keys.
_vocoder_local_paths = {
    "vocos": "../checkpoints/vocos-mel-24khz",
    "bigvgan": "../checkpoints/bigvgan_v2_24khz_100band_256x",
}
vocoder_local_path = _vocoder_local_paths[vocoder_name]

vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
131
-
132
-
133
# ---------------------------------------------------------------------------
# Select the model architecture and, when no local checkpoint was supplied,
# fetch pretrained weights from the Hugging Face hub (cached locally).
# The checkpoint chosen for F5-TTS depends on the selected vocoder.
# ---------------------------------------------------------------------------
if model == "F5-TTS":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
    if ckpt_file == "":
        if vocoder_name == "vocos":
            repo_name, exp_name, ckpt_step = "F5-TTS", "F5TTS_Base", 1200000
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
            # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
        elif vocoder_name == "bigvgan":
            repo_name, exp_name, ckpt_step = "F5-TTS", "F5TTS_Base_bigvgan", 1250000
            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))

elif model == "E2-TTS":
    assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
    if ckpt_file == "":
        repo_name, exp_name, ckpt_step = "E2-TTS", "E2TTS_Base", 1200000
        ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
        # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path


print(f"Using {model}...")
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
164
-
165
-
166
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
    """Synthesize *text_gen* and write the concatenated waveform to disk.

    The text may contain inline ``[voice_name]`` tags; each tagged chunk is
    synthesized with that voice's reference audio/text taken from the
    module-level ``config``.  The (ref_audio, ref_text) pair passed here
    becomes the "main" voice used for untagged chunks.  Output goes to the
    module-level ``wave_path``; optionally silence is stripped afterwards.

    Args:
        ref_audio: path to the main-voice reference audio clip.
        ref_text: transcript of the main-voice reference audio.
        text_gen: the full text to generate, possibly with [voice] tags.
        model_obj: loaded TTS model passed through to infer_process.
        mel_spec_type: vocoder/mel type name ("vocos" | "bigvgan").
        remove_silence: truthy to post-process the wav with silence removal.
        speed: generation speed factor forwarded to infer_process.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice
    # Normalize every voice's reference pair once, up front.
    for voice in voices:
        voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
            voices[voice]["ref_audio"], voices[voice]["ref_text"]
        )
        print("Voice:", voice)
        print("Ref_audio:", voices[voice]["ref_audio"])
        print("Ref_text:", voices[voice]["ref_text"])

    generated_audio_segments = []
    reg1 = r"(?=\[\w+\])"  # lookahead split: keep each [tag] at its chunk start
    chunks = re.split(reg1, text_gen)
    reg2 = r"\[(\w+)\]"  # captures the voice name inside the tag
    for text in chunks:
        if not text.strip():
            continue
        match = re.match(reg2, text)
        if match:
            voice = match[1]
        else:
            print("No voice tag found, using main.")
            voice = "main"
        if voice not in voices:
            print(f"Voice {voice} not found, using main.")
            voice = "main"
        text = re.sub(reg2, "", text)
        gen_text = text.strip()
        ref_audio = voices[voice]["ref_audio"]
        ref_text = voices[voice]["ref_text"]
        print(f"Voice: {voice}")
        audio, final_sample_rate, spectrogram = infer_process(
            ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
        )
        generated_audio_segments.append(audio)

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)

        # exist_ok=True avoids the check-then-create race of the original
        # `if not os.path.exists(...): os.makedirs(...)`.
        os.makedirs(output_dir, exist_ok=True)

        with open(wave_path, "wb") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            # Remove silence
            if remove_silence:
                remove_silence_for_generated_wav(f.name)
            print(f.name)
219
-
220
-
221
def main():
    """Script entry point: run synthesis with the settings resolved at import time."""
    main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)


if __name__ == "__main__":
    main()