import base64
import io
import traceback

import librosa
import numpy as np
import soundfile as sf
import torch
from snac import SNAC
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
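

# Endpoint handler for a SNAC-based TTS model served with vLLM. A request
# either enrolls a speaker ("enroll_user": True, returning serialized cloning
# features) or synthesizes speech from "inputs", optionally cloning an
# enrolled voice via previously returned "cloning_features".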
class EndpointHandler:
    def __init__(self, path=""):
        # Special token ids that delimit text and speech segments in the
        # prompt. Generated audio codes occupy the vocabulary range starting
        # at AUDIO_TOKENS_START, with 7 interleaved SNAC codes per audio frame.
        self.START_OF_HUMAN = 128259
        self.START_OF_TEXT = 128000
        self.END_OF_TEXT = 128009
        self.END_OF_HUMAN = 128260
        self.START_OF_AI = 128261
        self.START_OF_SPEECH = 128257
        self.END_OF_SPEECH = 128258
        self.END_OF_AI = 128262
        self.AUDIO_TOKENS_START = 128266

        # Keep vLLM's GPU memory share low; the SNAC codec shares the device.
        self.model = LLM(path, max_model_len=4096, gpu_memory_utilization=0.3)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
        except Exception as e:
            raise RuntimeError(f"Failed to load SNAC model: {e}")

    def encode_text(self, text):
        return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)

    def encode_audio(self, base64_audio_str):
        audio_bytes = base64.b64decode(base64_audio_str)
        audio_buffer = io.BytesIO(audio_bytes)
        waveform, sr = sf.read(audio_buffer, dtype='float32')

        # Downmix to mono and resample to SNAC's expected 24 kHz.
        if waveform.ndim > 1:
            waveform = np.mean(waveform, axis=1)
        if sr != 24000:
            waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000)
        return self.tokenize_audio(waveform)

    def format_text_block(self, text_ids):
        return [
            torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
            torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
            text_ids,
            torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
            torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64),
        ]

    def format_audio_block(self, audio_codes):
        return [
            torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
            torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
            torch.tensor([audio_codes], dtype=torch.int64),
            torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64),
            torch.tensor([[self.END_OF_AI]], dtype=torch.int64),
        ]

    def enroll_user(self, enrollment_pairs):
        """
        Parameters:
        - enrollment_pairs: list of (text, audio_data) tuples, where audio_data
          is base64-encoded audio

        Returns:
        - cloning_features (str): serialized enrollment data
        """
        enrollment_data = []

        for text, base64_audio in enrollment_pairs:
            text_ids = self.encode_text(text).cpu()
            audio_codes = self.encode_audio(base64_audio)
            enrollment_data.append({
                "text_ids": text_ids,
                "audio_codes": audio_codes,
            })

        # Serialize with torch.save, then base64-encode so the features can
        # be returned in a JSON response.
        buffer = io.BytesIO()
        torch.save(enrollment_data, buffer)
        buffer.seek(0)

        cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
        return cloning_features
    def prepare_audio_tokens_for_decoder(self, audio_codes_list):
        """
        Given a list of generated audio-code sequences:
        1. Trim each sequence to a multiple of 7 (the SNAC decoder expects
           7 tokens per audio frame).
        2. Shift token values back into the SNAC decoder's expected range.
        """
        modified_audio_codes_list = []
        for audio_codes in audio_codes_list:
            length = (audio_codes.size(0) // 7) * 7
            trimmed = audio_codes[:length]
            modified_audio_codes_list.append(trimmed - self.AUDIO_TOKENS_START)

        return modified_audio_codes_list

    def tokenize_audio(self, waveform):
        waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            codes = self.snac_model.encode(waveform)

        # SNAC produces three codebook layers at a 1:2:4 temporal ratio, so
        # each frame contributes 1 + 2 + 4 = 7 codes. Flatten them in
        # interleaved order and offset slot n by n * 4096 so every slot maps
        # to a distinct vocabulary range above AUDIO_TOKENS_START.
        all_codes = []
        for i in range(codes[0].shape[1]):
            all_codes.append(codes[0][0][i].item() + self.AUDIO_TOKENS_START + 0 * 4096)
            all_codes.append(codes[1][0][2 * i].item() + self.AUDIO_TOKENS_START + 1 * 4096)
            all_codes.append(codes[2][0][4 * i].item() + self.AUDIO_TOKENS_START + 2 * 4096)
            all_codes.append(codes[2][0][4 * i + 1].item() + self.AUDIO_TOKENS_START + 3 * 4096)
            all_codes.append(codes[1][0][2 * i + 1].item() + self.AUDIO_TOKENS_START + 4 * 4096)
            all_codes.append(codes[2][0][4 * i + 2].item() + self.AUDIO_TOKENS_START + 5 * 4096)
            all_codes.append(codes[2][0][4 * i + 3].item() + self.AUDIO_TOKENS_START + 6 * 4096)

        return all_codes

    def preprocess(self, data):
        self.voice_cloning = data.get("clone", False)

        target_text = data["inputs"]
        parameters = data.get("parameters", {})
        cloning_features = data.get("cloning_features", None)

        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))

        if self.voice_cloning:
            # Voice cloning: rebuild the enrollment prompt from the
            # serialized cloning features.
            if not cloning_features:
                raise ValueError("No cloning features were provided")

            enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))

            input_sequence = []
            for item in enrollment_data:
                input_sequence.extend(self.format_text_block(item["text_ids"]))
                input_sequence.extend(self.format_audio_block(item["audio_codes"]))

            # Append the target text, then open an AI speech block for the
            # model to continue in the enrolled voice.
            target_text_ids = self.encode_text(target_text)
            input_sequence.extend(self.format_text_block(target_text_ids))
            input_sequence.extend([
                torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
                torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
            ])

            input_ids = torch.cat(input_sequence, dim=1)

            # Override the token budget so it scales with the target text length.
            max_new_tokens = int(target_text_ids.size(1) * 20 + 200)

            input_ids = input_ids.to(self.device)
        else:
            # Plain TTS: prefix the target text with a named voice.
            voice = parameters.get("voice", "Eniola")
            prompt = f"{voice}: {target_text}"
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            input_ids = torch.cat(self.format_text_block(input_ids), dim=1)
            input_ids = input_ids.to(self.device)

        return {
            "input_ids": input_ids,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
        }

    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs.
        """
        input_ids = inputs["input_ids"]

        sampling_params = SamplingParams(
            temperature=inputs["temperature"],
            top_p=inputs["top_p"],
            max_tokens=inputs["max_new_tokens"],
            repetition_penalty=inputs["repetition_penalty"],
            stop_token_ids=[self.END_OF_SPEECH],
        )

        # vLLM is handed a prompt string here, so decode the ids back to text
        # (special tokens included) and let vLLM re-tokenize.
        prompt_string = self.tokenizer.decode(input_ids[0])

        generated_ids = self.model.generate(prompt_string, sampling_params)

        return torch.tensor(generated_ids[0].outputs[0].token_ids).unsqueeze(0)

    def __call__(self, data):
        try:
            enroll_user = data.get("enroll_user", False)

            if enroll_user:
                # Enrollment request: return serialized cloning features.
                enrollment_pairs = data.get("enrollments", [])
                cloning_features = self.enroll_user(enrollment_pairs)
                return {"cloning_features": cloning_features}
            else:
                # Synthesis request: preprocess, generate, decode to audio.
                preprocessed_inputs = self.preprocess(data)
                model_outputs = self.inference(preprocessed_inputs)
                response = self.postprocess(model_outputs)
                return response

        except Exception as e:
            traceback.print_exc()
            return {"error": str(e)}

    def convert_codes_to_waveform(self, code_list):
        """
        Reorganize a flat 7-tokens-per-frame sequence into the three SNAC
        codebook layers and decode it to a waveform.
        """
        layer_1 = []
        layer_2 = []
        layer_3 = []

        # Invert the interleaving applied in tokenize_audio: slot n of each
        # frame belongs to a fixed layer and carries an n * 4096 offset.
        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx + 0] - 0 * 4096)
            layer_2.append(code_list[idx + 1] - 1 * 4096)
            layer_3.append(code_list[idx + 2] - 2 * 4096)
            layer_3.append(code_list[idx + 3] - 3 * 4096)
            layer_2.append(code_list[idx + 4] - 4 * 4096)
            layer_3.append(code_list[idx + 5] - 5 * 4096)
            layer_3.append(code_list[idx + 6] - 6 * 4096)

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device),
        ]

        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def postprocess(self, generated_ids):
        if self.voice_cloning:
            # Cloning: the prompt already ended with START_OF_SPEECH, so the
            # generated tokens are audio codes from the first position on.
            code_lists = self.prepare_audio_tokens_for_decoder(generated_ids)

            audio = self.convert_codes_to_waveform(code_lists[0])
            audio_sample = audio.detach().squeeze().cpu().numpy()
        else:
            # Plain TTS: crop everything up to the last START_OF_SPEECH token,
            # then drop END_OF_SPEECH before decoding.
            token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)

            if len(token_indices[1]) > 0:
                last_occurrence_idx = token_indices[1][-1].item()
                cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
            else:
                cropped_tensor = generated_ids

            processed_rows = []
            for row in cropped_tensor:
                masked_row = row[row != self.END_OF_SPEECH]
                processed_rows.append(masked_row)

            code_lists = self.prepare_audio_tokens_for_decoder(processed_rows)

            audio_samples = []
            for code_list in code_lists:
                if len(code_list) > 0:
                    audio = self.convert_codes_to_waveform(code_list)
                    audio_samples.append(audio)
                else:
                    raise ValueError("Empty code list, no audio to generate")

            if not audio_samples:
                return {"error": "No audio samples generated"}

            audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Encode the waveform as 16-bit PCM WAV and base64 it for the response.
        buffer = io.BytesIO()
        sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16')
        buffer.seek(0)

        audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')

        return {
            "audio_sample": audio_sample,
            "audio_b64": audio_b64,
            "sample_rate": 24000,
        }
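

if __name__ == "__main__":
    # Minimal usage sketch, not part of the handler itself. "path/to/model"
    # and "enrollment.wav" are hypothetical placeholders; running this
    # requires a GPU, the served model weights, and an enrollment recording.
    handler = EndpointHandler("path/to/model")

    # 1) Enroll a speaker from (transcript, base64-encoded audio) pairs.
    with open("enrollment.wav", "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")
    enrolled = handler({
        "enroll_user": True,
        "enrollments": [("Hello, this is my voice.", audio_b64)],
    })

    # 2) Synthesize in the enrolled voice using the returned features.
    cloned = handler({
        "inputs": "Nice to meet you!",
        "clone": True,
        "cloning_features": enrolled["cloning_features"],
    })

    # 3) Or plain TTS with a named voice.
    plain = handler({
        "inputs": "Nice to meet you!",
        "parameters": {"voice": "Eniola", "temperature": 0.6},
    })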