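"""Custom inference endpoint handler for an Orpheus-style TTS model.

The handler supports two request types, dispatched in `__call__`:
  * enrollment: turn (text, audio) reference pairs into serialized
    voice-cloning features,
  * generation: synthesize speech, either standard TTS with a named voice
    or voice cloning conditioned on previously extracted features.

Text generation runs on vLLM; audio tokens are produced and decoded with SNAC.
"""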
import torch
import numpy as np
import librosa
import soundfile as sf
import traceback
import base64
import io
from transformers import AutoTokenizer
from snac import SNAC
from vllm import LLM, SamplingParams
class EndpointHandler:
def __init__(self, path=""):
# Delimiter tokens as defined in Orpheus' vocabulary
self.START_OF_HUMAN = 128259
self.START_OF_TEXT = 128000
self.END_OF_TEXT = 128009
self.END_OF_HUMAN = 128260
self.START_OF_AI = 128261
self.START_OF_SPEECH = 128257
self.END_OF_SPEECH = 128258
self.END_OF_AI = 128262
self.AUDIO_TOKENS_START = 128266
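        # Prompt layout built from these delimiters (see format_text_block /
        # format_audio_block below):
        #   [SOH][SOT] <text ids> [EOT][EOH]      one block per utterance text
        #   [SOA][SOS] <audio codes> [EOS][EOA]   one block per audio clip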
        # Load the language model via vLLM, along with its tokenizer
        self.model = LLM(path, max_model_len=4096, gpu_memory_utilization=0.3)
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # vLLM handles its own device placement; this device is for SNAC
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load SNAC model for audio decoding
try:
self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
self.snac_model.to(self.device)
except Exception as e:
raise RuntimeError(f"Failed to load SNAC model: {e}")
# Set up functions to format and encode text/audio
def encode_text(self, text):
return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
def encode_audio(self, base64_audio_str):
audio_bytes = base64.b64decode(base64_audio_str)
audio_buffer = io.BytesIO(audio_bytes)
waveform, sr = sf.read(audio_buffer, dtype='float32')
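        # snac_24khz expects mono audio at 24 kHz: downmix multi-channel
        # input and resample anything recorded at a different rate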
if waveform.ndim > 1:
waveform = np.mean(waveform, axis=1)
if sr != 24000:
waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000)
return self.tokenize_audio(waveform)
def format_text_block(self, text_ids):
return [
torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
text_ids,
torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64)
]
def format_audio_block(self, audio_codes):
return [
torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
torch.tensor([audio_codes], dtype=torch.int64),
torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64),
torch.tensor([[self.END_OF_AI]], dtype=torch.int64)
]
def enroll_user(self, enrollment_pairs):
"""
Parameters:
- enrollment_pairs: List of tuples (text, audio_data), where audio_data is
base64-encoded audio data
Returns:
- cloning_features (str): serialized enrollment data
"""
enrollment_data = []
for text, base64_audio in enrollment_pairs:
text_ids = self.encode_text(text).cpu()
audio_codes = self.encode_audio(base64_audio)
enrollment_data.append({
"text_ids": text_ids,
"audio_codes": audio_codes
})
# Serialize enrollment data
buffer = io.BytesIO()
torch.save(enrollment_data, buffer)
buffer.seek(0)
# Encode as base64 string and assign to attribute
cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
return cloning_features
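    # Illustrative enrollment request as handled by __call__ below; the audio
    # strings are placeholders for base64-encoded audio, not real data:
    #   {"enroll_user": True,
    #    "enrollments": [["First reference sentence.", "<base64 audio>"],
    #                    ["Second reference sentence.", "<base64 audio>"]]}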
def prepare_audio_tokens_for_decoder(self, audio_codes_list):
"""
Given a list containing sequences of generated audio codes, do the following:
1. Trim length to a multiple of 7 (SNAC decoder requires 7 tokens per audio frame)
2. Adjust token values to SNAC decoder's expected range
"""
modified_audio_codes_list = []
for audio_codes in audio_codes_list:
# Trim length to a multiple of 7
length = (audio_codes.size(0) // 7) * 7
trimmed = audio_codes[:length]
# Adjust token values to SNAC decoder's expected range
audio_codes = trimmed - self.AUDIO_TOKENS_START
# Add modified audio codes to list
modified_audio_codes_list.append(audio_codes)
return modified_audio_codes_list
    # Encode an audio waveform into Orpheus audio tokens via SNAC
def tokenize_audio(self, waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)
with torch.inference_mode():
codes = self.snac_model.encode(waveform)
all_codes = []
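        # snac_24khz encodes at three temporal resolutions: codes[0] carries 1
        # code per frame, codes[1] carries 2, and codes[2] carries 4. Each
        # frame is flattened into 7 tokens, and each of the 7 slots gets its
        # own 4096-wide offset so slot identity is preserved in the vocabulary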
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][(1 * i) + 0].item() + self.AUDIO_TOKENS_START + (0 * 4096))
all_codes.append(codes[1][0][(2 * i) + 0].item() + self.AUDIO_TOKENS_START + (1 * 4096))
all_codes.append(codes[2][0][(4 * i) + 0].item() + self.AUDIO_TOKENS_START + (2 * 4096))
all_codes.append(codes[2][0][(4 * i) + 1].item() + self.AUDIO_TOKENS_START + (3 * 4096))
all_codes.append(codes[1][0][(2 * i) + 1].item() + self.AUDIO_TOKENS_START + (4 * 4096))
all_codes.append(codes[2][0][(4 * i) + 2].item() + self.AUDIO_TOKENS_START + (5 * 4096))
all_codes.append(codes[2][0][(4 * i) + 3].item() + self.AUDIO_TOKENS_START + (6 * 4096))
return all_codes
def preprocess(self, data):
# Preprocess input data before inference
self.voice_cloning = data.get("clone", False)
# Extract parameters from request
target_text = data["inputs"]
parameters = data.get("parameters", {})
cloning_features = data.get("cloning_features", None)
temperature = float(parameters.get("temperature", 0.6))
top_p = float(parameters.get("top_p", 0.95))
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
        if self.voice_cloning:
            # Handle voice cloning using the provided cloning features
            if not cloning_features:
                raise ValueError("No cloning features were provided")
else:
# Decode back into tensors
enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
# Process pre-tokenized enrollment_data
input_sequence = []
for item in enrollment_data:
text_ids = item["text_ids"]
audio_codes = item["audio_codes"]
input_sequence.extend(self.format_text_block(text_ids))
input_sequence.extend(self.format_audio_block(audio_codes))
# Append target text whose audio we want
target_text_ids = self.encode_text(target_text)
input_sequence.extend(self.format_text_block(target_text_ids))
# Start of target audio - audio codes to be completed by model
input_sequence.extend([
torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64)
])
# Final input tensor
input_ids = torch.cat(input_sequence, dim=1)
                # Heuristic cap on generation length, based on an empirical
                # relationship between prompt token count and generated token count
                max_new_tokens = int(target_text_ids.size(1) * 20 + 200)
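                # e.g. a 30-token target text caps generation at 30 * 20 + 200
                # = 800 tokens, i.e. at most 800 // 7 = 114 SNAC frames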
input_ids = input_ids.to(self.device)
        else:
            # Handle standard text-to-speech with a named voice
            voice = parameters.get("voice", "Eniola")
            prompt = f"{voice}: {target_text}"
            # Encode without special tokens, as in the cloning path:
            # format_text_block already adds START_OF_TEXT, so letting the
            # tokenizer prepend its own BOS would duplicate that token
            input_ids = self.encode_text(prompt)
# Add special tokens
input_ids = torch.cat(self.format_text_block(input_ids), dim=1)
# No need for padding as we're processing a single sequence
input_ids = input_ids.to(self.device)
return {
"input_ids": input_ids,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
def inference(self, inputs):
"""
Run model inference on the preprocessed inputs
"""
# Extract parameters
input_ids = inputs["input_ids"]
        sampling_params = SamplingParams(
            temperature=inputs["temperature"],
            top_p=inputs["top_p"],
            max_tokens=inputs["max_new_tokens"],
            repetition_penalty=inputs["repetition_penalty"],
            stop_token_ids=[self.END_OF_SPEECH],
        )
        # vLLM consumes a string prompt, so decode the ids back to text;
        # special and audio tokens round-trip through the tokenizer vocabulary
        prompt_string = self.tokenizer.decode(input_ids[0])
        # Forward pass; vLLM returns only the newly generated tokens
        generated_ids = self.model.generate(prompt_string, sampling_params)
        return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)
def __call__(self, data):
# Main entry point for the handler
try:
enroll_user = data.get("enroll_user", False)
if enroll_user:
# We extract cloning features for enrollment
enrollment_pairs = data.get("enrollments", [])
cloning_features = self.enroll_user(enrollment_pairs)
return {"cloning_features": cloning_features}
else:
                # Generate speech: standard TTS or cloning, per the request
preprocessed_inputs = self.preprocess(data)
model_outputs = self.inference(preprocessed_inputs)
response = self.postprocess(model_outputs)
return response
        # Log the full traceback and return the error message to the caller
except Exception as e:
traceback.print_exc()
return {"error": str(e)}
# Postprocess generated ids
def convert_codes_to_waveform(self, code_list):
"""
Reorganize tokens for SNAC decoding
"""
layer_1 = [] # Coarsest layer
layer_2 = [] # Intermediate layer
layer_3 = [] # Finest layer
num_groups = len(code_list) // 7
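        # Inverse of tokenize_audio: each 7-token frame [t0..t6] unpacks to
        # layer_1=[t0], layer_2=[t1, t4], layer_3=[t2, t3, t5, t6], with each
        # slot's 4096-wide offset subtracted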
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx + 0] - (0 * 4096))
            layer_2.append(code_list[idx + 1] - (1 * 4096))
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))
codes = [
torch.tensor(layer_1).unsqueeze(0).to(self.device),
torch.tensor(layer_2).unsqueeze(0).to(self.device),
torch.tensor(layer_3).unsqueeze(0).to(self.device)
]
        # Decode audio under inference mode (no gradients needed)
        with torch.inference_mode():
            audio_hat = self.snac_model.decode(codes)
return audio_hat
def postprocess(self, generated_ids):
        if self.voice_cloning:
            # For cloning, the prompt already ends at START_OF_SPEECH, so the
            # generated ids are audio tokens that can be decoded directly
            # Adjust the audio codes to be digestible by the SNAC decoder
            code_lists = self.prepare_audio_tokens_for_decoder(generated_ids)
            # Generate audio from the codes
            audio = self.convert_codes_to_waveform(code_lists[0])
            audio_sample = audio.detach().squeeze().to("cpu").numpy()
        else:
            # Standard TTS: locate the audio span in the generated tokens
            # Find the last START_OF_SPEECH token; the audio codes follow it
token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)
if len(token_indices[1]) > 0:
last_occurrence_idx = token_indices[1][-1].item()
cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
else:
cropped_tensor = generated_ids
# Remove End of Audio tokens
processed_rows = []
for row in cropped_tensor:
masked_row = row[row != self.END_OF_SPEECH]
processed_rows.append(masked_row)
code_lists = self.prepare_audio_tokens_for_decoder(processed_rows)
# Generate audio from codes
audio_samples = []
for code_list in code_lists:
if len(code_list) > 0:
audio = self.convert_codes_to_waveform(code_list)
audio_samples.append(audio)
else:
raise ValueError("Empty code list, no audio to generate")
if not audio_samples:
return {"error": "No audio samples generated"}
# Return first (and only) audio sample
audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
        # Write the float32 waveform to an in-memory WAV; soundfile converts
        # it to 16-bit PCM via the PCM_16 subtype
        buffer = io.BytesIO()
        sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16')
buffer.seek(0)
# Encode WAV bytes as base64
audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
        return {
            # Plain list keeps the response JSON-serializable
            "audio_sample": audio_sample.tolist(),
            "audio_b64": audio_b64,
            "sample_rate": 24000,
        }
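
# A minimal local smoke test: a sketch assuming a locally downloaded
# checkpoint; the path and prompt are illustrative, not part of the
# deployed endpoint contract.
if __name__ == "__main__":
    handler = EndpointHandler(path="./orpheus-checkpoint")  # hypothetical path
    result = handler({
        "inputs": "Hello from the inference endpoint.",
        "parameters": {"voice": "Eniola", "temperature": 0.6},
    })
    if "audio_b64" in result:
        with open("output.wav", "wb") as f:
            f.write(base64.b64decode(result["audio_b64"]))
        print(f"Wrote output.wav at {result['sample_rate']} Hz")
    else:
        print(result)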