# hypaai_orpheus / handler.py
import os
import torch
import numpy as np
import librosa
import soundfile as sf
import traceback
import base64
import io
from transformers import AutoTokenizer
from snac import SNAC
from vllm import LLM, SamplingParams
class EndpointHandler:
def __init__(self, path=""):
# Delimiter tokens as defined in Orpheus' vocabulary
self.START_OF_HUMAN = 128259
self.START_OF_TEXT = 128000
self.END_OF_TEXT = 128009
self.END_OF_HUMAN = 128260
self.START_OF_AI = 128261
self.START_OF_SPEECH = 128257
self.END_OF_SPEECH = 128258
self.END_OF_AI = 128262
self.AUDIO_TOKENS_START = 128266
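        # Audio code tokens occupy 7 interleaved codebooks of 4096 ids each, starting at AUDIO_TOKENS_START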
# Load the models and tokenizer
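        # gpu_memory_utilization is capped at 0.3, which leaves GPU memory for the SNAC decoder loaded below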
        self.model = LLM(path, max_model_len=4096, gpu_memory_utilization=0.3)
self.tokenizer = AutoTokenizer.from_pretrained(path)
        # vLLM handles its own device placement; this device is only used for the SNAC model
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load SNAC model for audio decoding
try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
self.snac_model.to(self.device)
except Exception as e:
raise RuntimeError(f"Failed to load SNAC model: {e}")
# Set up functions to format and encode text/audio
def encode_text(self, text):
return self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)
def encode_audio(self, base64_audio_str):
audio_bytes = base64.b64decode(base64_audio_str)
audio_buffer = io.BytesIO(audio_bytes)
waveform, sr = sf.read(audio_buffer, dtype='float32')
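        # Collapse multi-channel audio to mono and resample to the 24 kHz rate SNAC expects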
if waveform.ndim > 1:
waveform = np.mean(waveform, axis=1)
if sr != 24000:
waveform = librosa.resample(waveform, orig_sr=sr, target_sr=24000)
return self.tokenize_audio(waveform)
def format_text_block(self, text_ids):
return [
torch.tensor([[self.START_OF_HUMAN]], dtype=torch.int64),
torch.tensor([[self.START_OF_TEXT]], dtype=torch.int64),
text_ids,
torch.tensor([[self.END_OF_TEXT]], dtype=torch.int64),
torch.tensor([[self.END_OF_HUMAN]], dtype=torch.int64)
]
def format_audio_block(self, audio_codes):
return [
torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64),
torch.tensor([audio_codes], dtype=torch.int64),
torch.tensor([[self.END_OF_SPEECH]], dtype=torch.int64),
torch.tensor([[self.END_OF_AI]], dtype=torch.int64)
]
def enroll_user(self, enrollment_pairs):
"""
Parameters:
- enrollment_pairs: List of tuples (text, audio_data), where audio_data is
base64-encoded audio data
Returns:
- cloning_features (str): serialized enrollment data
"""
enrollment_data = []
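        # Each (text, base64 audio) pair becomes a {"text_ids", "audio_codes"} record;
        # prompt formatting is deferred to preprocess()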
for text, base64_audio in enrollment_pairs:
text_ids = self.encode_text(text).cpu()
audio_codes = self.encode_audio(base64_audio)
enrollment_data.append({
"text_ids": text_ids,
"audio_codes": audio_codes
})
# Serialize enrollment data
buffer = io.BytesIO()
torch.save(enrollment_data, buffer)
buffer.seek(0)
# Encode as base64 string and assign to attribute
cloning_features = base64.b64encode(buffer.read()).decode('utf-8')
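        # Clients can cache this string and send it back as "cloning_features" on later synthesis requests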
return cloning_features
def prepare_audio_tokens_for_decoder(self, audio_codes_list):
"""
Given a list containing sequences of generated audio codes, do the following:
1. Trim length to a multiple of 7 (SNAC decoder requires 7 tokens per audio frame)
2. Adjust token values to SNAC decoder's expected range
"""
modified_audio_codes_list = []
for audio_codes in audio_codes_list:
# Trim length to a multiple of 7
length = (audio_codes.size(0) // 7) * 7
trimmed = audio_codes[:length]
# Adjust token values to SNAC decoder's expected range
audio_codes = trimmed - self.AUDIO_TOKENS_START
# Add modified audio codes to list
modified_audio_codes_list.append(audio_codes)
return modified_audio_codes_list
    # Convert an audio waveform into a flat sequence of SNAC code tokens
def tokenize_audio(self, waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0).unsqueeze(0).to(self.device)
with torch.inference_mode():
codes = self.snac_model.encode(waveform)
all_codes = []
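        # Interleave the three SNAC codebooks into flat 7-token frames, offsetting each
        # slot into its own 4096-wide id range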
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][(1 * i) + 0].item() + self.AUDIO_TOKENS_START + (0 * 4096))
all_codes.append(codes[1][0][(2 * i) + 0].item() + self.AUDIO_TOKENS_START + (1 * 4096))
all_codes.append(codes[2][0][(4 * i) + 0].item() + self.AUDIO_TOKENS_START + (2 * 4096))
all_codes.append(codes[2][0][(4 * i) + 1].item() + self.AUDIO_TOKENS_START + (3 * 4096))
all_codes.append(codes[1][0][(2 * i) + 1].item() + self.AUDIO_TOKENS_START + (4 * 4096))
all_codes.append(codes[2][0][(4 * i) + 2].item() + self.AUDIO_TOKENS_START + (5 * 4096))
all_codes.append(codes[2][0][(4 * i) + 3].item() + self.AUDIO_TOKENS_START + (6 * 4096))
return all_codes
def preprocess(self, data):
# Preprocess input data before inference
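        # Expected synthesis payload (a sketch of the fields read below): {"inputs": "<text>",
        #   "clone": bool, "cloning_features": "<base64 from enroll_user>", "parameters": {...}}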
self.voice_cloning = data.get("clone", False)
# Extract parameters from request
target_text = data["inputs"]
parameters = data.get("parameters", {})
cloning_features = data.get("cloning_features", None)
temperature = float(parameters.get("temperature", 0.6))
top_p = float(parameters.get("top_p", 0.95))
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
if self.voice_cloning:
"""Handle voice cloning using cloning features"""
if not cloning_features:
raise ValueError("No cloning features were provided")
else:
# Decode back into tensors
enrollment_data = torch.load(io.BytesIO(base64.b64decode(cloning_features)))
# Process pre-tokenized enrollment_data
input_sequence = []
for item in enrollment_data:
text_ids = item["text_ids"]
audio_codes = item["audio_codes"]
input_sequence.extend(self.format_text_block(text_ids))
input_sequence.extend(self.format_audio_block(audio_codes))
# Append target text whose audio we want
target_text_ids = self.encode_text(target_text)
input_sequence.extend(self.format_text_block(target_text_ids))
# Start of target audio - audio codes to be completed by model
input_sequence.extend([
torch.tensor([[self.START_OF_AI]], dtype=torch.int64),
torch.tensor([[self.START_OF_SPEECH]], dtype=torch.int64)
])
# Final input tensor
input_ids = torch.cat(input_sequence, dim=1)
                # Heuristic: the number of generated audio tokens scales roughly
                # linearly with the length of the target text
                max_new_tokens = int(target_text_ids.size(1) * 20 + 200)
input_ids = input_ids.to(self.device)
else:
# Handle standard text-to-speech
# Extract parameters from request
voice = parameters.get("voice", "Eniola")
prompt = f"{voice}: {target_text}"
            # Encode without the tokenizer's default special tokens; format_text_block
            # adds START_OF_TEXT itself, matching the cloning path
            input_ids = self.encode_text(prompt)
            # Wrap with the human-turn delimiter tokens
input_ids = torch.cat(self.format_text_block(input_ids), dim=1)
# No need for padding as we're processing a single sequence
input_ids = input_ids.to(self.device)
return {
"input_ids": input_ids,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens,
"repetition_penalty": repetition_penalty,
}
def inference(self, inputs):
"""
Run model inference on the preprocessed inputs
"""
# Extract parameters
input_ids = inputs["input_ids"]
        sampling_params = SamplingParams(
            temperature=inputs["temperature"],
            top_p=inputs["top_p"],
            max_tokens=inputs["max_new_tokens"],
            repetition_penalty=inputs["repetition_penalty"],
            stop_token_ids=[self.END_OF_SPEECH],
        )
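        # The prompt is passed to vLLM as a string; the Orpheus special ids decode to
        # added-token strings that vLLM's tokenizer maps back to the same ids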
prompt_string = self.tokenizer.decode(input_ids[0])
        # Generate with vLLM; the result is a list of RequestOutput objects
        request_outputs = self.model.generate(prompt_string, sampling_params)
        # Single prompt -> single request; return its generated token ids shaped (1, n)
        return torch.tensor(request_outputs[0].outputs[0].token_ids).unsqueeze(0)
def __call__(self, data):
# Main entry point for the handler
try:
enroll_user = data.get("enroll_user", False)
if enroll_user:
# We extract cloning features for enrollment
enrollment_pairs = data.get("enrollments", [])
cloning_features = self.enroll_user(enrollment_pairs)
return {"cloning_features": cloning_features}
else:
                # Generate speech: standard TTS or cloning, depending on the request
preprocessed_inputs = self.preprocess(data)
model_outputs = self.inference(preprocessed_inputs)
response = self.postprocess(model_outputs)
return response
        # Log the full traceback and return the error message to the caller
except Exception as e:
traceback.print_exc()
return {"error": str(e)}
# Postprocess generated ids
def convert_codes_to_waveform(self, code_list):
"""
Reorganize tokens for SNAC decoding
"""
layer_1 = [] # Coarsest layer
layer_2 = [] # Intermediate layer
layer_3 = [] # Finest layer
num_groups = len(code_list) // 7
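        # Each 7-token frame contributes 1 code to the coarse layer, 2 to the middle
        # layer, and 4 to the fine layer, undoing the interleaving in tokenize_audio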
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx + 0] - (0 * 4096))
            layer_2.append(code_list[idx + 1] - (1 * 4096))
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))
codes = [
torch.tensor(layer_1).unsqueeze(0).to(self.device),
torch.tensor(layer_2).unsqueeze(0).to(self.device),
torch.tensor(layer_3).unsqueeze(0).to(self.device)
]
# Decode audio
audio_hat = self.snac_model.decode(codes)
return audio_hat
def postprocess(self, generated_ids):
if self.voice_cloning:
"""
For cloning applications, use this postprocess function to get generated audio samples
"""
# Modify audio codes to be digestible byb SNAC decoder
code_lists = self.prepare_audio_tokens_for_decoder(generated_ids)
# Generate audio from codes
temp = self.convert_codes_to_waveform(code_lists[0])
audio_sample = temp.detach().squeeze().to("cpu").numpy()
else:
"""
Process generated tokens into audio
"""
# Find Start of Audio token
token_indices = (generated_ids == self.START_OF_SPEECH).nonzero(as_tuple=True)
if len(token_indices[1]) > 0:
last_occurrence_idx = token_indices[1][-1].item()
cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
else:
cropped_tensor = generated_ids
            # Remove any END_OF_SPEECH tokens
processed_rows = []
for row in cropped_tensor:
masked_row = row[row != self.END_OF_SPEECH]
processed_rows.append(masked_row)
code_lists = self.prepare_audio_tokens_for_decoder(processed_rows)
# Generate audio from codes
audio_samples = []
for code_list in code_lists:
if len(code_list) > 0:
audio = self.convert_codes_to_waveform(code_list)
audio_samples.append(audio)
else:
raise ValueError("Empty code list, no audio to generate")
if not audio_samples:
return {"error": "No audio samples generated"}
# Return first (and only) audio sample
audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()
        # Write the waveform to an in-memory 16-bit PCM WAV file
        buffer = io.BytesIO()
        sf.write(buffer, audio_sample, samplerate=24000, format='WAV', subtype='PCM_16')
buffer.seek(0)
# Encode WAV bytes as base64
audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
        return {
            "audio_sample": audio_sample.tolist(),  # raw float samples, JSON-serializable
            "audio_b64": audio_b64,
            "sample_rate": 24000,
        }
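# Minimal local smoke test (a sketch, not part of the deployed endpoint): assumes an
# Orpheus checkpoint is available at MODEL_PATH and a CUDA-capable GPU is present.
# MODEL_PATH and the example text are illustrative values, not part of the handler API.
if __name__ == "__main__":
    MODEL_PATH = os.environ.get("MODEL_PATH", ".")  # hypothetical local path to the weights
    handler = EndpointHandler(MODEL_PATH)
    result = handler({
        "inputs": "Hello from Orpheus.",
        "parameters": {"voice": "Eniola", "temperature": 0.6},
    })
    if "audio_b64" in result:
        # Write the returned base64 WAV to disk for a quick listen
        with open("output.wav", "wb") as f:
            f.write(base64.b64decode(result["audio_b64"]))
        print(f"Wrote output.wav ({result['sample_rate']} Hz)")
    else:
        print("Generation failed:", result.get("error"))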