# MOSS-TTSD / generation_utils.py
import os
import re
import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer
from modeling_asteroid import AsteroidTTSInstruct
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
MAX_CHANNELS = 8
SILENCE_DURATION = 5.0 # Fixed silence duration: 5 seconds
def load_model(model_path, spt_config_path, spt_checkpoint_path):
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa")
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
model.eval()
spt.eval()
return tokenizer, model, spt
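
# Usage sketch (paths are placeholders, not shipped with this module):
#   tokenizer, model, spt = load_model(
#       "path/to/MOSS-TTSD",        # hypothetical HF model directory
#       "path/to/spt_config.yaml",  # hypothetical XY_Tokenizer config
#       "path/to/spt.ckpt",         # hypothetical XY_Tokenizer checkpoint
#   )
#   model.to("cuda"); spt.to("cuda")
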
def process_jsonl_item(item):
"""Process JSONL data items and extract audio and text information according to the new format"""
base_path = item.get("base_path", "")
text = item.get("text", "")
# Process prompt audio and text
if "prompt_audio" in item and "prompt_text" in item:
print("Using prompt_audio and prompt_text directly from item.")
# If prompt_audio and prompt_text exist, use them directly
prompt_audio = item["prompt_audio"]
prompt_text = item["prompt_text"]
# Only perform path joining when prompt_audio is a string path
if isinstance(prompt_audio, str) and base_path and prompt_audio:
prompt_audio = os.path.join(base_path, prompt_audio)
else:
print("Using speaker1 and speaker2 information for prompt audio and text.")
# Otherwise, merge speaker1 and speaker2 information
prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
prompt_text_speaker2 = item.get("prompt_text_speaker2", "")
# Process audio: if it's a string path, perform path joining; if it's a tuple, use directly
if isinstance(prompt_audio_speaker1, str):
speaker1_audio = os.path.join(base_path, prompt_audio_speaker1) if base_path and prompt_audio_speaker1 else prompt_audio_speaker1
else:
speaker1_audio = prompt_audio_speaker1 # Use tuple directly
if isinstance(prompt_audio_speaker2, str):
speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
else:
speaker2_audio = prompt_audio_speaker2 # Use tuple directly
prompt_audio = {
"speaker1": speaker1_audio,
"speaker2": speaker2_audio
}
# Merge text
prompt_text = ""
if prompt_text_speaker1:
prompt_text += f"[S1]{prompt_text_speaker1}"
if prompt_text_speaker2:
prompt_text += f"[S2]{prompt_text_speaker2}"
prompt_text = prompt_text.strip()
return {
"text": text,
"prompt_text": prompt_text,
"prompt_audio": prompt_audio
}
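
# The two JSONL item shapes accepted above (field values are illustrative):
#   {"base_path": "/data", "text": "[S1]...[S2]...",
#    "prompt_audio": "ref.wav", "prompt_text": "[S1]...[S2]..."}
#   {"base_path": "/data", "text": "[S1]...[S2]...",
#    "prompt_audio_speaker1": "s1.wav", "prompt_text_speaker1": "...",
#    "prompt_audio_speaker2": "s2.wav", "prompt_text_speaker2": "..."}
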
def load_audio_data(prompt_audio, target_sample_rate=16000):
"""Load audio data and return processed audio tensor
Args:
prompt_audio: Can be in the following formats:
- String: audio file path
- Tuple: (wav, sr) result from torchaudio.load
- Dict: {"speaker1": path_or_tuple, "speaker2": path_or_tuple}
"""
if prompt_audio is None:
return None
try:
# Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
# Process audio from both speakers separately
wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
# Merge audio from both speakers
wav = merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate)
if wav is None:
return None
else:
# Single audio
wav, sr = _load_single_audio(prompt_audio)
            # Resample to the target sample rate
if sr != target_sample_rate:
wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
# Ensure mono channel
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
if len(wav.shape) == 1:
wav = wav.unsqueeze(0)
return wav
except Exception as e:
print(f"Error loading audio data: {e}")
raise
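
# Accepted prompt_audio forms (illustrative values):
#   load_audio_data("ref.wav")                                # file path
#   load_audio_data(torchaudio.load("ref.wav"))               # (wav, sr) tuple
#   load_audio_data({"speaker1": "s1.wav", "speaker2": "s2.wav"})
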
def _load_single_audio(audio_input):
"""Load single audio, supports file path or (wav, sr) tuple
Args:
audio_input: String (file path) or tuple (wav, sr)
Returns:
tuple: (wav, sr)
"""
if isinstance(audio_input, tuple) and len(audio_input) == 2:
# Already a (wav, sr) tuple
wav, sr = audio_input
return wav, sr
elif isinstance(audio_input, str):
# Is a file path, needs to be loaded
wav, sr = torchaudio.load(audio_input)
return wav, sr
else:
raise ValueError(f"Unsupported audio input format: {type(audio_input)}")
def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
"""Merge audio data from two speakers"""
try:
# Process first audio
if sr1 != target_sample_rate:
wav1 = torchaudio.functional.resample(wav1, sr1, target_sample_rate)
# Ensure mono channel
if wav1.shape[0] > 1:
wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
if len(wav1.shape) == 1:
wav1 = wav1.unsqueeze(0)
# Process second audio
if sr2 != target_sample_rate:
wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
# Ensure mono channel
if wav2.shape[0] > 1:
wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
if len(wav2.shape) == 1:
wav2 = wav2.unsqueeze(0)
# Concatenate audio
merged_wav = torch.cat([wav1, wav2], dim=1)
return merged_wav
except Exception as e:
print(f"Error merging audio: {e}")
raise
def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
    """Build the (seq_len, max_channels) input id matrix: text ids on channel 0, optional prompt-audio codes appended"""
seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
inputs1 = np.array(tokenizer.encode(seq))
input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
input_ids[:, 0] = inputs1
if audio_data is not None:
try:
# audio_data should now be a processed audio tensor
wav = audio_data
# Add fixed 5-second silence at the end of audio (using 16k sample rate)
silence_samples = int(SILENCE_DURATION * 16000)
silence = torch.zeros(wav.shape[0], silence_samples)
wav = torch.cat([wav, silence], dim=1)
            with torch.no_grad():
                # Encode the prompt audio with the XY_Tokenizer (SPT)
                encode_result = spt.encode([wav.squeeze().to(device)])
                # (channels, frames) -> (frames, channels)
                audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
            # Offset channel-0 codes into the text-token id space
            audio_token[:, 0] = audio_token[:, 0] + 151665
            # Append the audio codes, dropping the last 60 code frames
            # (trimming part of the appended silence)
            input_ids = np.concatenate([input_ids, audio_token])[:-60]
except Exception as e:
print(f"Error processing audio data: {e}")
raise
return input_ids
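
# Resulting layout with defaults (T text tokens, A prompt-audio frames; the
# final [:-60] trim assumes A > 60):
#   rows 0..T-1:    channel 0 holds text ids, channels 1..7 hold pad (1024)
#   rows T..T+A-61: all 8 channels hold SPT codes, channel 0 offset by 151665
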
def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
    """Apply a per-channel delay pattern: channel c is shifted right by c rows"""
seq_len = input_ids.shape[0]
new_seq_len = seq_len + max_channels - 1
shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
for i in range(max_channels):
shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
return shifted_input_ids
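
# Delay-pattern illustration for max_channels=3 and a 2-row input
# (P = pad; channel 0 uses tokenizer.pad_token_id, other channels use pad_token):
#   row:   0    1    2    3
#   ch0:  a0   a1    P    P
#   ch1:   P   b0   b1    P
#   ch2:   P    P   c0   c1
# Frame t of channel c lands on row t + c, so each row only depends on
# earlier frames of the higher channels.
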
def rpadding(input_ids, channels, tokenizer):
    """Pad each (seq_len, channels) array to the batch max length and build matching attention masks"""
attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
max_length = max(ids.shape[0] for ids in input_ids)
padded_input_ids, padded_attns = [], []
for ids, attn in zip(input_ids, attention_masks):
pad_len = max_length - ids.shape[0]
input_pad = np.full((pad_len, channels), 1024)
input_pad[:, 0] = tokenizer.pad_token_id
padded_input_ids.append(np.concatenate([input_pad, ids]))
attn_pad = np.zeros(pad_len)
padded_attns.append(np.concatenate([attn_pad, attn]))
input_ids = torch.tensor(np.stack(padded_input_ids))
attention_mask = torch.tensor(np.stack(padded_attns))
return input_ids, attention_mask
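
# Note: despite the name, rpadding prepends the padding (left padding), so all
# sequences in the batch end at the same row and generation starts together.
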
def find_max_valid_positions(C: torch.Tensor, invalid_value=1024) -> torch.Tensor:
    """For each batch row, return the last index whose channel-1 value differs from invalid_value (-1 if none)"""
values = C[:, :, 1]
mask = (values != invalid_value)
reversed_mask = mask.flip(dims=[1])
reversed_indices = torch.argmax(reversed_mask.int(), dim=1)
seq_len = C.size(1)
original_indices = seq_len - 1 - reversed_indices
has_valid = mask.any(dim=1)
original_indices = torch.where(has_valid, original_indices, -1)
return original_indices
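
# Example (illustrative): with invalid_value=1024 and channel-1 values
#   [[  7, 1024,    9, 1024],
#    [1024, 1024, 1024, 1024]]
# the result is tensor([2, -1]): the last valid index per row, -1 if none.
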
def normalize_text(text: str) -> str:
"""
Normalize multi-speaker script.
1. Don't preserve line breaks.
2. Remove brackets for non-speaker tags (if [] doesn't contain S1/S2...Sx format, remove the brackets themselves).
3. Remove decorative symbols: 【】《》()『』「」"-“” .
4. Internal punctuation !;:、 → ,;only allow ? and ,。
5. Multiple 。 keep only the last one, others → ,。
6. Replace consecutive "哈" (>=2) with "(笑)".
7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
"""
# Replace [1], [2] etc. format with [S1], [S2] etc. format
text = re.sub(r'\[(\d+)\]', r'[S\1]', text)
    # Characters stripped as decoration
    remove_chars = '【】《》()『』「」“”"-'
# Remove brackets for non-speaker tags (keep content, only remove brackets themselves)
text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)
# Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
normalized_lines = []
for seg in segments:
seg = seg.strip()
if not seg:
continue
# Extract tags
m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
tag, content = m.groups() if m else ('', seg)
# Remove irrelevant symbols
content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
# Handle consecutive "哈" characters: replace 2 or more with "(笑)"
content = re.sub(r'哈{2,}', '(笑)', content)
# First handle multi-character punctuation marks
content = content.replace('——', ',')
content = content.replace('……', ',')
# Handle single-character internal punctuation marks
internal_punct_map = str.maketrans({
'!': ',', '!': ',',
';': ',', ';': ',',
':': ',', ':': ',',
'、': ',',
'?': ',', '?': ','
})
content = content.translate(internal_punct_map)
content = content.strip()
        # Ensure the segment ends with a period: a trailing fullwidth or ASCII
        # comma becomes the matching period, and earlier 。 become ,
        if len(content) > 1:
            last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
            body = content[:-1].replace('。', ',')
            content = body + last_ch
normalized_lines.append(f"{tag}{content}".strip())
return "".join(normalized_lines)
def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
"""Process a batch of data items and generate audio, return audio data and metadata"""
try:
# Prepare batch data
batch_size = len(batch_items)
texts = []
prompts = [system_prompt] * batch_size
prompt_audios = []
actual_texts_data = [] # Store actual text data used
print(f"Processing {batch_size} samples starting from index {start_idx}...")
# Extract text and audio from each sample
for i, item in enumerate(batch_items):
# Use new processing function
processed_item = process_jsonl_item(item)
text = processed_item["text"]
prompt_text = processed_item["prompt_text"]
# Merge text
full_text = prompt_text + text
original_full_text = full_text # Save original text
# Apply text normalization based on parameter
if use_normalize:
full_text = normalize_text(full_text)
# Replace speaker tags
final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
texts.append(final_text)
# Save actual text information used
actual_texts_data.append({
"index": start_idx + i,
"original_text": original_full_text,
"normalized_text": normalize_text(original_full_text) if use_normalize else None,
"final_text": final_text,
"use_normalize": use_normalize
})
# Get reference audio
prompt_audios.append(processed_item["prompt_audio"])
# Process inputs
input_ids_list = []
for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
# Load audio data here
audio_data = load_audio_data(audio_path) if audio_path else None
inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
inputs = shifting_inputs(inputs, tokenizer)
input_ids_list.append(inputs)
# Pad batch inputs
input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)
# Batch generation
print(f"Starting batch audio generation...")
start = input_ids.shape[1] - MAX_CHANNELS + 1
# Move inputs to GPU
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
# Generate model outputs
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
)
print(f"Original outputs shape: {outputs.shape}")
print(f"Start value: {start}")
print(f"Shape after slicing: {outputs[:, start:].shape}")
print(f"MAX_CHANNELS: {MAX_CHANNELS}")
print(f"Calculated seq_len: {outputs.shape[1] - MAX_CHANNELS + 1}")
# Process outputs
outputs = outputs[:, start:]
seq_len = outputs.shape[1] - MAX_CHANNELS + 1
speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)
# Adjust output format
for j in range(MAX_CHANNELS):
speech_ids[..., j] = outputs[:, j : seq_len + j, j]
            if j == 0:
                # Undo the text-vocabulary offset applied to channel 0
                speech_ids[..., j] = speech_ids[..., j] - 151665
# Find valid positions for each sample
li = find_max_valid_positions(speech_ids)
# Store audio result data
audio_results = []
# Process batch sample results individually
for i in range(batch_size):
try:
# Extract valid speech tokens
end_idx = li[i] + 1
if end_idx <= 0:
print(f"Sample {start_idx + i} has no valid speech tokens")
audio_results.append(None)
continue
this_speech_id = speech_ids[i, :end_idx]
print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")
# Decode generated audio
with torch.no_grad():
codes_list = [this_speech_id.permute(1, 0)] # Convert to SPT expected format
decode_result = spt.decode(codes_list, overlap_seconds=10)
audio_result = decode_result["syn_wav_list"][0].cpu().detach()
if audio_result.ndim == 1: # If 1D [samples]
audio_result = audio_result.unsqueeze(0) # Convert to 2D [1, samples]
# Save audio data instead of file path
audio_results.append({
"audio_data": audio_result,
"sample_rate": spt.output_sample_rate,
"index": start_idx + i
})
print(f"Audio generation completed: sample {start_idx + i}")
except Exception as e:
print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
import traceback
traceback.print_exc()
audio_results.append(None)
# Clean up GPU memory
torch.cuda.empty_cache()
# Return text data and audio data
return actual_texts_data, audio_results
except Exception as e:
print(f"Error during batch processing: {str(e)}")
raise
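
# End-to-end sketch (file names, model paths, and SYSTEM_PROMPT are
# placeholders, not defined by this module):
#   import json
#   tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG, SPT_CKPT)
#   device = "cuda"
#   model.to(device); spt.to(device)
#   with open("input.jsonl") as f:
#       items = [json.loads(line) for line in f]
#   texts_meta, audio_results = process_batch(
#       items, tokenizer, model, spt, device,
#       system_prompt=SYSTEM_PROMPT, start_idx=0, use_normalize=True)
#   for r in filter(None, audio_results):
#       torchaudio.save(f"output_{r['index']}.wav", r["audio_data"], r["sample_rate"])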