|
import os
import re

import numpy as np
import torch
import torchaudio
from transformers import AutoTokenizer

from modeling_asteroid import AsteroidTTSInstruct
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
|
|
|
MAX_CHANNELS = 8
SILENCE_DURATION = 5.0  # seconds of silence appended after the prompt audio before encoding
|
|
|
def load_model(model_path, spt_config_path, spt_checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AsteroidTTSInstruct.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
    spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
    model.eval()
    spt.eval()
    return tokenizer, model, spt
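

# Illustrative usage of load_model (a hedged sketch; the paths below are hypothetical
# placeholders, not files shipped with this repository):
#   tokenizer, model, spt = load_model(
#       "path/to/asteroid_tts_model",
#       "path/to/xy_tokenizer_config.yaml",
#       "path/to/xy_tokenizer.ckpt",
#   )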
|
|
|
|
|
def process_jsonl_item(item):
    """Process a JSONL item and extract its audio and text information according to the new format."""
    base_path = item.get("base_path", "")
    text = item.get("text", "")

    if "prompt_audio" in item and "prompt_text" in item:
        print("Using prompt_audio and prompt_text directly from item.")
        prompt_audio = item["prompt_audio"]
        prompt_text = item["prompt_text"]

        # Resolve a relative audio path against base_path when the prompt audio is a file path.
        if isinstance(prompt_audio, str) and base_path and prompt_audio:
            prompt_audio = os.path.join(base_path, prompt_audio)
    else:
        print("Using speaker1 and speaker2 information for prompt audio and text.")
        prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
        prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
        prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
        prompt_text_speaker2 = item.get("prompt_text_speaker2", "")

        if isinstance(prompt_audio_speaker1, str):
            speaker1_audio = os.path.join(base_path, prompt_audio_speaker1) if base_path and prompt_audio_speaker1 else prompt_audio_speaker1
        else:
            speaker1_audio = prompt_audio_speaker1

        if isinstance(prompt_audio_speaker2, str):
            speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
        else:
            speaker2_audio = prompt_audio_speaker2

        prompt_audio = {
            "speaker1": speaker1_audio,
            "speaker2": speaker2_audio
        }

        # Concatenate the per-speaker prompt texts with their speaker tags.
        prompt_text = ""
        if prompt_text_speaker1:
            prompt_text += f"[S1]{prompt_text_speaker1}"
        if prompt_text_speaker2:
            prompt_text += f"[S2]{prompt_text_speaker2}"
        prompt_text = prompt_text.strip()

    return {
        "text": text,
        "prompt_text": prompt_text,
        "prompt_audio": prompt_audio
    }
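

# The two JSONL item layouts handled above (hedged illustration; the file names and texts are
# hypothetical placeholders):
#   {"base_path": "/data", "text": "[S1]...[S2]...",
#    "prompt_audio": "prompt.wav", "prompt_text": "[S1]...[S2]..."}
# or, with per-speaker prompts:
#   {"base_path": "/data", "text": "[S1]...[S2]...",
#    "prompt_audio_speaker1": "s1.wav", "prompt_text_speaker1": "...",
#    "prompt_audio_speaker2": "s2.wav", "prompt_text_speaker2": "..."}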
|
|
|
|
|
def load_audio_data(prompt_audio, target_sample_rate=16000):
    """Load audio data and return a processed audio tensor.

    Args:
        prompt_audio: Can be in one of the following formats:
            - String: audio file path
            - Tuple: (wav, sr) result from torchaudio.load
            - Dict: {"speaker1": path_or_tuple, "speaker2": path_or_tuple}
    """
    if prompt_audio is None:
        return None

    try:
        if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
            # Two-speaker prompt: load each speaker and concatenate them in time.
            wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
            wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
            wav = merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate)
            if wav is None:
                return None
        else:
            # Single prompt: load, resample to the target rate, and downmix to mono.
            wav, sr = _load_single_audio(prompt_audio)
            if sr != target_sample_rate:
                wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)
            if len(wav.shape) == 1:
                wav = wav.unsqueeze(0)

        return wav
    except Exception as e:
        print(f"Error loading audio data: {e}")
        raise
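

# Illustrative calls (hedged; the paths are hypothetical placeholders):
#   wav = load_audio_data("/data/prompt.wav")            # -> mono tensor of shape (1, num_samples) at 16 kHz
#   wav = load_audio_data({"speaker1": "/data/s1.wav",
#                          "speaker2": "/data/s2.wav"})  # -> the two prompts concatenated in time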
|
|
|
|
|
def _load_single_audio(audio_input):
    """Load a single audio prompt; supports a file path or a (wav, sr) tuple.

    Args:
        audio_input: String (file path) or tuple (wav, sr)

    Returns:
        tuple: (wav, sr)
    """
    if isinstance(audio_input, tuple) and len(audio_input) == 2:
        # Already-loaded audio: pass it through unchanged.
        wav, sr = audio_input
        return wav, sr
    elif isinstance(audio_input, str):
        # File path: load it with torchaudio.
        wav, sr = torchaudio.load(audio_input)
        return wav, sr
    else:
        raise ValueError(f"Unsupported audio input format: {type(audio_input)}")
|
|
|
|
|
def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
    """Merge the prompt audio of two speakers into a single waveform."""
    try:
        # Resample speaker 1 to the target rate and downmix to mono.
        if sr1 != target_sample_rate:
            wav1 = torchaudio.functional.resample(wav1, sr1, target_sample_rate)
        if wav1.shape[0] > 1:
            wav1 = wav1.mean(dim=0, keepdim=True)
        if len(wav1.shape) == 1:
            wav1 = wav1.unsqueeze(0)

        # Resample speaker 2 to the target rate and downmix to mono.
        if sr2 != target_sample_rate:
            wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
        if wav2.shape[0] > 1:
            wav2 = wav2.mean(dim=0, keepdim=True)
        if len(wav2.shape) == 1:
            wav2 = wav2.unsqueeze(0)

        # Concatenate the two mono prompts along the time axis.
        merged_wav = torch.cat([wav1, wav2], dim=1)
        return merged_wav
    except Exception as e:
        print(f"Error merging audio: {e}")
        raise
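

# Shape sketch (hedged illustration): for two mono prompts of 3 s and 2 s at 16 kHz,
#   merge_speaker_audios(wav1, 16000, wav2, 16000) -> tensor of shape (1, 80000)
# i.e. speaker1's 48000 samples followed by speaker2's 32000 samples.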
|
|
|
|
|
def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
    # Build the text part of the sequence and place it in channel 0; the remaining
    # channels are filled with the audio pad token.
    seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
    inputs1 = np.array(tokenizer.encode(seq))
    input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
    input_ids[:, 0] = inputs1

    if audio_data is not None:
        try:
            wav = audio_data

            # Append a fixed amount of silence after the prompt audio.
            silence_samples = int(SILENCE_DURATION * 16000)
            silence = torch.zeros(wav.shape[0], silence_samples)
            wav = torch.cat([wav, silence], dim=1)

            # Encode the prompt audio into multi-channel speech tokens.
            with torch.no_grad():
                encode_result = spt.encode([wav.squeeze().to(device)])
                audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy()

            # Shift channel 0 by the text-vocabulary offset, append the speech tokens,
            # and trim the last 60 token frames.
            audio_token[:, 0] = audio_token[:, 0] + 151665
            input_ids = np.concatenate([input_ids, audio_token])[:-60]
        except Exception as e:
            print(f"Error processing audio data: {e}")
            raise

    return input_ids
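

# Resulting layout (hedged reading of the code above): process_inputs returns a (T, 8) integer
# array. The first rows carry the tokenized style/text prompt in channel 0 with pad_token (1024)
# in channels 1-7; when prompt audio is supplied, the remaining rows carry the XY_Tokenizer
# speech codes, with channel 0 shifted by 151665 into the text-token id space.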
|
|
|
|
|
def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
    # Apply a delay pattern: channel j is shifted down by j positions, so channel j is
    # predicted one step later than channel j-1.
    seq_len = input_ids.shape[0]
    new_seq_len = seq_len + max_channels - 1
    shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
    shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
    for i in range(max_channels):
        shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
    return shifted_input_ids
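

# Delay-pattern sketch (hedged illustration with max_channels=3 and pad_token=P):
#   input rows  a0 a1 a2          shifted rows  a0  P  P
#               b0 b1 b2    ->                  b0 a1  P
#               c0 c1 c2                        c0 b1 a2
#                                                P c1 b2
#                                                P  P c2
# (channel 0 of the two padding rows holds tokenizer.pad_token_id rather than P).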
|
|
|
|
|
def rpadding(input_ids, channels, tokenizer):
    # Left-pad every sequence in the batch to the same length and build the attention masks.
    attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
    max_length = max(ids.shape[0] for ids in input_ids)
    padded_input_ids, padded_attns = [], []

    for ids, attn in zip(input_ids, attention_masks):
        pad_len = max_length - ids.shape[0]
        input_pad = np.full((pad_len, channels), 1024)
        input_pad[:, 0] = tokenizer.pad_token_id
        padded_input_ids.append(np.concatenate([input_pad, ids]))
        attn_pad = np.zeros(pad_len)
        padded_attns.append(np.concatenate([attn_pad, attn]))

    input_ids = torch.tensor(np.stack(padded_input_ids))
    attention_mask = torch.tensor(np.stack(padded_attns))

    return input_ids, attention_mask
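

# Padding sketch (hedged illustration): for a batch with sequence lengths [5, 3] and channels=8,
#   rpadding([ids_a, ids_b], 8, tokenizer)
# returns input_ids of shape (2, 5, 8) and attention_mask of shape (2, 5); the shorter sequence
# is padded on the left, so its first two mask entries are 0.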
|
|
|
|
|
def find_max_valid_positions(C: torch.Tensor, invalid_value=1024) -> torch.Tensor:
    # For each sequence in the batch, find the last position whose channel-1 token is not the
    # pad value; return -1 for sequences with no valid position.
    values = C[:, :, 1]
    mask = (values != invalid_value)
    reversed_mask = mask.flip(dims=[1])
    reversed_indices = torch.argmax(reversed_mask.int(), dim=1)
    seq_len = C.size(1)
    original_indices = seq_len - 1 - reversed_indices
    has_valid = mask.any(dim=1)
    original_indices = torch.where(has_valid, original_indices, -1)
    return original_indices
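

# Example (hedged): with invalid_value=1024 and channel-1 values [7, 9, 1024, 1024] for one batch
# element, find_max_valid_positions returns index 1 for that element; an all-1024 row yields -1.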
|
|
|
|
|
def normalize_text(text: str) -> str:
    """
    Normalize a multi-speaker script.

    1. Line breaks are not preserved.
    2. Strip brackets that do not wrap an S1/S2/.../Sx speaker tag, keeping their content.
    3. Remove decorative symbols: 【】《》()『』「」"-“”
    4. Convert internal punctuation !;:、? (and their full-width variants) to ,
    5. When there are multiple 。, keep only the final one and convert the others to ,
    6. Replace runs of two or more "哈" with "(笑)".
    7. Auto-recognize [S1] / [S2] ... tags; if none are present, treat the text as a single segment.
    """

    # Rewrite numeric tags such as [1] into speaker tags such as [S1].
    text = re.sub(r'\[(\d+)\]', r'[S\1]', text)

    # Decorative characters to strip from the content.
    remove_chars = "【】《》()『』「」\"-“”"

    # Strip brackets that do not wrap a speaker tag, keeping their content.
    text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)

    # Split into per-speaker segments on the speaker tags, dropping line breaks.
    segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
    normalized_lines = []

    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue

        # Separate the speaker tag from its content (no tag -> the whole segment is content).
        m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
        tag, content = m.groups() if m else ('', seg)

        # Remove decorative symbols.
        content = re.sub(f"[{re.escape(remove_chars)}]", "", content)

        # Replace laughter runs with (笑).
        content = re.sub(r'哈{2,}', '(笑)', content)

        # Dashes and ellipses become commas.
        content = content.replace('——', ',')
        content = content.replace('……', ',')

        # Map internal punctuation to commas.
        internal_punct_map = str.maketrans({
            '!': ',', '!': ',',
            ';': ',', ';': ',',
            ':': ',', ':': ',',
            '、': ',',
            '?': ',', '?': ','
        })
        content = content.translate(internal_punct_map)
        content = content.strip()

        # Keep only the final sentence-ending period; earlier 。 become commas, and a
        # trailing comma is turned into a period.
        if len(content) > 1:
            last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
            body = content[:-1].replace('。', ',')
            content = body + last_ch

        normalized_lines.append(f"{tag}{content}".strip())

    return "".join(normalized_lines)
|
|
|
|
|
def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
    """Process a batch of data items, generate audio, and return the audio data together with text metadata."""
    try:
        batch_size = len(batch_items)
        texts = []
        prompts = [system_prompt] * batch_size
        prompt_audios = []
        actual_texts_data = []

        print(f"Processing {batch_size} samples starting from index {start_idx}...")

        # Assemble the text inputs and collect the prompt audio references.
        for i, item in enumerate(batch_items):
            processed_item = process_jsonl_item(item)

            text = processed_item["text"]
            prompt_text = processed_item["prompt_text"]

            full_text = prompt_text + text
            original_full_text = full_text

            if use_normalize:
                full_text = normalize_text(full_text)

            # Map speaker tags onto the model's speaker tokens.
            final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
            texts.append(final_text)

            actual_texts_data.append({
                "index": start_idx + i,
                "original_text": original_full_text,
                "normalized_text": normalize_text(original_full_text) if use_normalize else None,
                "final_text": final_text,
                "use_normalize": use_normalize
            })

            prompt_audios.append(processed_item["prompt_audio"])

        # Build the per-sample input id matrices (text + encoded prompt audio, delay-shifted).
        input_ids_list = []
        for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
            audio_data = load_audio_data(audio_path) if audio_path else None
            inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
            inputs = shifting_inputs(inputs, tokenizer)
            input_ids_list.append(inputs)

        input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)

        print("Starting batch audio generation...")
        start = input_ids.shape[1] - MAX_CHANNELS + 1

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        print(f"Original outputs shape: {outputs.shape}")
        print(f"Start value: {start}")
        print(f"Shape after slicing: {outputs[:, start:].shape}")
        print(f"MAX_CHANNELS: {MAX_CHANNELS}")
        print(f"Calculated seq_len: {outputs.shape[1] - MAX_CHANNELS + 1}")

        # Drop the prompt prefix and undo the per-channel delay pattern.
        outputs = outputs[:, start:]
        seq_len = outputs.shape[1] - MAX_CHANNELS + 1
        speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)

        for j in range(MAX_CHANNELS):
            speech_ids[..., j] = outputs[:, j : seq_len + j, j]
            if j == 0:
                # Remove the text-vocabulary offset from channel 0.
                speech_ids[..., j] = speech_ids[..., j] - 151665

        # Find the last valid speech-token position for each sample.
        li = find_max_valid_positions(speech_ids)

        audio_results = []

        # Decode each sample's speech tokens back into a waveform.
        for i in range(batch_size):
            try:
                end_idx = li[i] + 1
                if end_idx <= 0:
                    print(f"Sample {start_idx + i} has no valid speech tokens")
                    audio_results.append(None)
                    continue

                this_speech_id = speech_ids[i, :end_idx]
                print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")

                with torch.no_grad():
                    codes_list = [this_speech_id.permute(1, 0)]
                    decode_result = spt.decode(codes_list, overlap_seconds=10)
                    audio_result = decode_result["syn_wav_list"][0].cpu().detach()

                    if audio_result.ndim == 1:
                        audio_result = audio_result.unsqueeze(0)

                audio_results.append({
                    "audio_data": audio_result,
                    "sample_rate": spt.output_sample_rate,
                    "index": start_idx + i
                })
                print(f"Audio generation completed: sample {start_idx + i}")

            except Exception as e:
                print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
                import traceback
                traceback.print_exc()
                audio_results.append(None)

        # Free cached GPU memory between batches.
        torch.cuda.empty_cache()

        return actual_texts_data, audio_results

    except Exception as e:
        print(f"Error during batch processing: {str(e)}")
        raise
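

# Minimal end-to-end driver (a hedged sketch, not this repository's actual entry point; the
# model/tokenizer paths, the JSONL file name, and the empty system prompt are hypothetical
# placeholders):
#
#   if __name__ == "__main__":
#       import json
#
#       device = "cuda" if torch.cuda.is_available() else "cpu"
#       tokenizer, model, spt = load_model(
#           "path/to/asteroid_tts_model", "path/to/xy_tokenizer_config.yaml", "path/to/xy_tokenizer.ckpt"
#       )
#       model = model.to(device)
#       spt = spt.to(device)
#
#       with open("examples/input.jsonl", "r", encoding="utf-8") as f:
#           items = [json.loads(line) for line in f if line.strip()]
#
#       texts_meta, audio_results = process_batch(
#           items, tokenizer, model, spt, device,
#           system_prompt="", start_idx=0, use_normalize=True,
#       )
#       for result in audio_results:
#           if result is not None:
#               torchaudio.save(f"output_{result['index']}.wav", result["audio_data"], result["sample_rate"])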
|
|