import os
import re

import torch
import torchaudio
import numpy as np
from transformers import AutoTokenizer

from modeling_asteroid import AsteroidTTSInstruct
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer

MAX_CHANNELS = 8
SILENCE_DURATION = 5.0  # Fixed silence duration: 5 seconds


def load_model(model_path, spt_config_path, spt_checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AsteroidTTSInstruct.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
    )
    spt = XY_Tokenizer.load_from_checkpoint(
        config_path=spt_config_path, ckpt_path=spt_checkpoint_path
    )
    model.eval()
    spt.eval()
    return tokenizer, model, spt


def process_jsonl_item(item):
    """Process a JSONL data item and extract audio and text information according to the new format."""
    base_path = item.get("base_path", "")
    text = item.get("text", "")

    # Process prompt audio and text
    if "prompt_audio" in item and "prompt_text" in item:
        print("Using prompt_audio and prompt_text directly from item.")
        # If prompt_audio and prompt_text exist, use them directly
        prompt_audio = item["prompt_audio"]
        prompt_text = item["prompt_text"]
        # Only perform path joining when prompt_audio is a string path
        if isinstance(prompt_audio, str) and base_path and prompt_audio:
            prompt_audio = os.path.join(base_path, prompt_audio)
    else:
        print("Using speaker1 and speaker2 information for prompt audio and text.")
        # Otherwise, merge the speaker1 and speaker2 information
        prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
        prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
        prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
        prompt_text_speaker2 = item.get("prompt_text_speaker2", "")

        # Process audio: join string paths with base_path; pass (wav, sr) tuples through unchanged
        if isinstance(prompt_audio_speaker1, str):
            speaker1_audio = (
                os.path.join(base_path, prompt_audio_speaker1)
                if base_path and prompt_audio_speaker1
                else prompt_audio_speaker1
            )
        else:
            speaker1_audio = prompt_audio_speaker1  # Use the tuple directly

        if isinstance(prompt_audio_speaker2, str):
            speaker2_audio = (
                os.path.join(base_path, prompt_audio_speaker2)
                if base_path and prompt_audio_speaker2
                else prompt_audio_speaker2
            )
        else:
            speaker2_audio = prompt_audio_speaker2  # Use the tuple directly

        prompt_audio = {
            "speaker1": speaker1_audio,
            "speaker2": speaker2_audio,
        }

        # Merge the text
        prompt_text = ""
        if prompt_text_speaker1:
            prompt_text += f"[S1]{prompt_text_speaker1}"
        if prompt_text_speaker2:
            prompt_text += f"[S2]{prompt_text_speaker2}"
        prompt_text = prompt_text.strip()

    return {
        "text": text,
        "prompt_text": prompt_text,
        "prompt_audio": prompt_audio,
    }
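# Illustrative JSONL item layouts accepted by process_jsonl_item(); the paths
# and texts below are placeholders, the keys follow the parsing logic above:
#
#   {"base_path": "/path/to/prompts", "text": "[S1]...[S2]...",
#    "prompt_audio": "prompt.wav", "prompt_text": "[S1]...[S2]..."}
#
#   {"base_path": "/path/to/prompts", "text": "[S1]...[S2]...",
#    "prompt_audio_speaker1": "spk1.wav", "prompt_text_speaker1": "...",
#    "prompt_audio_speaker2": "spk2.wav", "prompt_text_speaker2": "..."}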
def load_audio_data(prompt_audio, target_sample_rate=16000):
    """Load audio data and return a processed audio tensor.

    Args:
        prompt_audio: Can be one of the following formats:
            - String: audio file path
            - Tuple: (wav, sr) result from torchaudio.load
            - Dict: {"speaker1": path_or_tuple, "speaker2": path_or_tuple}
    """
    if prompt_audio is None:
        return None

    try:
        # Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
        if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
            # Process the audio of both speakers separately
            wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
            wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
            # Merge the audio of both speakers
            wav = merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate)
            if wav is None:
                return None
        else:
            # Single audio
            wav, sr = _load_single_audio(prompt_audio)
            # Resample to 16k
            if sr != target_sample_rate:
                wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
            # Ensure mono channel
            if wav.shape[0] > 1:
                wav = wav.mean(dim=0, keepdim=True)  # Convert multi-channel to mono
            if len(wav.shape) == 1:
                wav = wav.unsqueeze(0)
        return wav
    except Exception as e:
        print(f"Error loading audio data: {e}")
        raise


def _load_single_audio(audio_input):
    """Load a single audio prompt; supports a file path or a (wav, sr) tuple.

    Args:
        audio_input: String (file path) or tuple (wav, sr)

    Returns:
        tuple: (wav, sr)
    """
    if isinstance(audio_input, tuple) and len(audio_input) == 2:
        # Already a (wav, sr) tuple
        wav, sr = audio_input
        return wav, sr
    elif isinstance(audio_input, str):
        # A file path, needs to be loaded
        wav, sr = torchaudio.load(audio_input)
        return wav, sr
    else:
        raise ValueError(f"Unsupported audio input format: {type(audio_input)}")


def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
    """Merge the audio data of two speakers."""
    try:
        # Process the first audio
        if sr1 != target_sample_rate:
            wav1 = torchaudio.functional.resample(wav1, sr1, target_sample_rate)
        # Ensure mono channel
        if wav1.shape[0] > 1:
            wav1 = wav1.mean(dim=0, keepdim=True)  # Convert multi-channel to mono
        if len(wav1.shape) == 1:
            wav1 = wav1.unsqueeze(0)

        # Process the second audio
        if sr2 != target_sample_rate:
            wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
        # Ensure mono channel
        if wav2.shape[0] > 1:
            wav2 = wav2.mean(dim=0, keepdim=True)  # Convert multi-channel to mono
        if len(wav2.shape) == 1:
            wav2 = wav2.unsqueeze(0)

        # Concatenate the audio
        merged_wav = torch.cat([wav1, wav2], dim=1)
        return merged_wav
    except Exception as e:
        print(f"Error merging audio: {e}")
        raise


def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
    seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
    inputs1 = np.array(tokenizer.encode(seq))
    input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
    input_ids[:, 0] = inputs1

    if audio_data is not None:
        try:
            # audio_data should already be a processed audio tensor
            wav = audio_data
            # Append a fixed 5-second silence to the end of the audio (at the 16k sample rate)
            silence_samples = int(SILENCE_DURATION * 16000)
            silence = torch.zeros(wav.shape[0], silence_samples)
            wav = torch.cat([wav, silence], dim=1)

            with torch.no_grad():
                # Use SPT encoding
                encode_result = spt.encode([wav.squeeze().to(device)])
                audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy()  # Adjust dimension order
            # Similar to the DAC encoding adjustment: offset the channel-0 codes
            # (keep this line if the offset is needed, otherwise delete it)
            audio_token[:, 0] = audio_token[:, 0] + 151665
            # Drop the last 60 frames (roughly the appended silence)
            input_ids = np.concatenate([input_ids, audio_token])[:-60]
        except Exception as e:
            print(f"Error processing audio data: {e}")
            raise
    return input_ids


def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
    seq_len = input_ids.shape[0]
    new_seq_len = seq_len + max_channels - 1
    shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
    shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
    for i in range(max_channels):
        shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
    return shifted_input_ids
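# Illustration of the delay pattern produced by shifting_inputs(): channel i is
# shifted down by i steps. Column 0 is pre-filled with tokenizer.pad_token_id
# and the remaining columns with pad_token (1024); "P" marks the respective
# pad value:
#
#   step:   0    1    2    3   ...
#   ch 0:  t0   t1   t2   t3
#   ch 1:   P   a0   a1   a2
#   ch 2:   P    P   b0   b1
#   ...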
def rpadding(input_ids, channels, tokenizer):
    attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
    max_length = max(ids.shape[0] for ids in input_ids)
    padded_input_ids, padded_attns = [], []

    for ids, attn in zip(input_ids, attention_masks):
        pad_len = max_length - ids.shape[0]
        input_pad = np.full((pad_len, channels), 1024)
        input_pad[:, 0] = tokenizer.pad_token_id
        padded_input_ids.append(np.concatenate([input_pad, ids]))
        attn_pad = np.zeros(pad_len)
        padded_attns.append(np.concatenate([attn_pad, attn]))

    input_ids = torch.tensor(np.stack(padded_input_ids))
    attention_mask = torch.tensor(np.stack(padded_attns))
    return input_ids, attention_mask


def find_max_valid_positions(C: torch.Tensor, invalid_value=1024) -> torch.Tensor:
    values = C[:, :, 1]
    mask = (values != invalid_value)
    reversed_mask = mask.flip(dims=[1])
    reversed_indices = torch.argmax(reversed_mask.int(), dim=1)
    seq_len = C.size(1)
    original_indices = seq_len - 1 - reversed_indices
    has_valid = mask.any(dim=1)
    original_indices = torch.where(has_valid, original_indices, -1)
    return original_indices
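# Illustration: find_max_valid_positions() returns, per batch item, the index
# of the last time step whose token in channel index 1 differs from the pad
# value (1024), or -1 if no such step exists. For a batch of two items with
#
#   C[:, :, 1] == [[7, 1024, 1024],
#                  [3,    9, 1024]]
#
# it returns tensor([0, 1]).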
def normalize_text(text: str) -> str:
    """
    Normalize a multi-speaker script.
    1. Do not preserve line breaks.
    2. For non-speaker tags (i.e. [] that does not contain the S1/S2...Sx format), remove the brackets themselves but keep their content.
    3. Remove decorative symbols: 【】《》()『』「」"-“”.
    4. Map internal punctuation (!, ;, :, 、, ? and their full-width forms) to commas, so that only commas and 。 remain.
    5. When there are multiple 。, keep only the last one; the others become commas.
    6. Replace runs of two or more consecutive "哈" with "(笑)".
    7. Auto-recognize [S1] / [S2] ... tags; if they are missing, treat the text as a single segment.
    """
    # Replace [1], [2] etc. with [S1], [S2] etc.
    text = re.sub(r'\[(\d+)\]', r'[S\1]', text)

    # Decorative characters to remove
    remove_chars = '【】《》()『』「」"-“”'

    # Remove the brackets of non-speaker tags (keep the content, drop only the brackets)
    text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)

    # Use a positive lookahead to split the text by speaker tags (the tags themselves are preserved)
    segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))

    normalized_lines = []
    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        # Extract the tag
        m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
        tag, content = m.groups() if m else ('', seg)
        # Remove irrelevant symbols
        content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
        # Handle consecutive "哈" characters: replace 2 or more with "(笑)"
        content = re.sub(r'哈{2,}', '(笑)', content)
        # First handle multi-character punctuation marks
        content = content.replace('——', ',')
        content = content.replace('……', ',')
        # Handle single-character internal punctuation marks
        internal_punct_map = str.maketrans({
            '！': ',', '!': ',',
            '；': ',', ';': ',',
            '：': ',', ':': ',',
            '、': ',',
            '？': ',', '?': ','
        })
        content = content.translate(internal_punct_map)
        content = content.strip()
        # Keep only the final period
        if len(content) > 1:
            last_ch = "。" if content[-1] == "，" else ("." if content[-1] == "," else content[-1])
            body = content[:-1].replace('。', ',')
            content = body + last_ch
        normalized_lines.append(f"{tag}{content}".strip())

    return "".join(normalized_lines)
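# Illustrative call of normalize_text() (hypothetical input; the output follows
# the rules implemented above):
#
#   normalize_text("[1]哈哈哈,太好了!今天可以出去玩。[2]是啊。我们走吧。")
#   -> "[S1](笑),太好了,今天可以出去玩。[S2]是啊,我们走吧。"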
def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
    """Process a batch of data items, generate audio, and return the audio data plus metadata."""
    try:
        # Prepare batch data
        batch_size = len(batch_items)
        texts = []
        prompts = [system_prompt] * batch_size
        prompt_audios = []
        actual_texts_data = []  # Store the actual text used for each sample

        print(f"Processing {batch_size} samples starting from index {start_idx}...")

        # Extract text and audio from each sample
        for i, item in enumerate(batch_items):
            # Use the new processing function
            processed_item = process_jsonl_item(item)
            text = processed_item["text"]
            prompt_text = processed_item["prompt_text"]

            # Merge the text
            full_text = prompt_text + text
            original_full_text = full_text  # Save the original text

            # Apply text normalization if requested
            if use_normalize:
                full_text = normalize_text(full_text)

            # Remove the speaker tags
            final_text = full_text.replace("[S1]", "").replace("[S2]", "")
            texts.append(final_text)

            # Save the actual text information used
            actual_texts_data.append({
                "index": start_idx + i,
                "original_text": original_full_text,
                "normalized_text": normalize_text(original_full_text) if use_normalize else None,
                "final_text": final_text,
                "use_normalize": use_normalize,
            })

            # Get the reference audio
            prompt_audios.append(processed_item["prompt_audio"])

        # Process inputs
        input_ids_list = []
        for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
            # Load the audio data here
            audio_data = load_audio_data(audio_path) if audio_path else None
            inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
            inputs = shifting_inputs(inputs, tokenizer)
            input_ids_list.append(inputs)

        # Pad the batch inputs
        input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)

        # Batch generation
        print("Starting batch audio generation...")
        start = input_ids.shape[1] - MAX_CHANNELS + 1

        # Move the inputs to the GPU
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Generate model outputs
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        print(f"Original outputs shape: {outputs.shape}")
        print(f"Start value: {start}")
        print(f"Shape after slicing: {outputs[:, start:].shape}")
        print(f"MAX_CHANNELS: {MAX_CHANNELS}")
        print(f"Calculated seq_len: {outputs.shape[1] - MAX_CHANNELS + 1}")

        # Process the outputs
        outputs = outputs[:, start:]
        seq_len = outputs.shape[1] - MAX_CHANNELS + 1
        speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)

        # Undo the channel delay pattern and remove the channel-0 offset
        for j in range(MAX_CHANNELS):
            speech_ids[..., j] = outputs[:, j : seq_len + j, j]
            if j == 0:
                speech_ids[..., j] = speech_ids[..., j] - 151665

        # Find the valid positions for each sample
        li = find_max_valid_positions(speech_ids)

        # Store the audio result data
        audio_results = []

        # Process the results of each sample in the batch individually
        for i in range(batch_size):
            try:
                # Extract the valid speech tokens
                end_idx = li[i] + 1
                if end_idx <= 0:
                    print(f"Sample {start_idx + i} has no valid speech tokens")
                    audio_results.append(None)
                    continue

                this_speech_id = speech_ids[i, :end_idx]
                print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")

                # Decode the generated audio
                with torch.no_grad():
                    codes_list = [this_speech_id.permute(1, 0)]  # Convert to the format expected by SPT
                    decode_result = spt.decode(codes_list, overlap_seconds=10)
                    audio_result = decode_result["syn_wav_list"][0].cpu().detach()
                    if audio_result.ndim == 1:  # If 1D [samples]
                        audio_result = audio_result.unsqueeze(0)  # Convert to 2D [1, samples]

                # Save the audio data instead of a file path
                audio_results.append({
                    "audio_data": audio_result,
                    "sample_rate": spt.output_sample_rate,
                    "index": start_idx + i,
                })

                print(f"Audio generation completed: sample {start_idx + i}")

            except Exception as e:
                print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
                import traceback
                traceback.print_exc()
                audio_results.append(None)

        # Clean up GPU memory
        torch.cuda.empty_cache()

        # Return the text data and audio data
        return actual_texts_data, audio_results

    except Exception as e:
        print(f"Error during batch processing: {str(e)}")
        raise
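# --- Usage sketch (illustrative only, not part of the original script) ---
# All paths, the JSONL file name and the system prompt below are placeholders;
# the call sequence simply chains the functions defined above, assuming a
# single CUDA device (or CPU fallback) is available.
if __name__ == "__main__":
    import json

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, model, spt = load_model(
        "path/to/model_dir",                 # placeholder model directory
        "path/to/xy_tokenizer_config.yaml",  # placeholder SPT config
        "path/to/xy_tokenizer.ckpt",         # placeholder SPT checkpoint
    )
    model = model.to(device)
    spt = spt.to(device)

    with open("input.jsonl", "r", encoding="utf-8") as f:  # placeholder input file
        items = [json.loads(line) for line in f if line.strip()]

    system_prompt = ""  # placeholder style prompt, inserted between the style tags by process_inputs()
    texts_data, audio_results = process_batch(
        items, tokenizer, model, spt, device, system_prompt,
        start_idx=0, use_normalize=True,
    )

    for result in audio_results:
        if result is not None:
            torchaudio.save(
                f"output_{result['index']}.wav",
                result["audio_data"],
                result["sample_rate"],
            )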