"""Inference Endpoints handler for the Hypa Orpheus TTS model.

Generates SNAC audio-code tokens with a causal LM and decodes them into 24 kHz
WAV audio, returned either as a single base64 payload or as a stream of
server-sent events.
"""
import base64
import io
import json
import logging
import os
import wave
from threading import Thread

import numpy as np
import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path=""):
        logger.info("Initializing Orpheus TTS handler")

        # Load the fine-tuned Orpheus LM in bfloat16.
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logger.info(f"Model loaded on {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        logger.info("Tokenizer loaded")

        # SNAC vocoder that turns generated audio codes into a 24 kHz waveform.
        try:
            self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
            self.snac_model.to(self.device)
            logger.info("SNAC model loaded")
        except Exception as e:
            logger.error(f"Error loading SNAC: {str(e)}")
            raise

        # Special token IDs used to wrap the text prompt and delimit the audio-code stream.
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
        self.start_audio_token = 128257
        self.end_audio_token = 128258

        self._warmed_up = False
        logger.info("Handler initialization complete")

    def preprocess(self, data):
        """Preprocess input data before inference."""
        logger.info(f"Preprocessing data: {type(data)}")

        if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
            return {"health_check": True}

        if isinstance(data, dict) and "inputs" in data:
            text = data["inputs"]
            parameters = data.get("parameters", {})
        else:
            text = data
            parameters = {}

        voice = parameters.get("voice", "tara")
        temperature = float(parameters.get("temperature", 0.6))
        top_p = float(parameters.get("top_p", 0.95))
        max_new_tokens = int(parameters.get("max_new_tokens", 1200))
        repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
        stream = parameters.get("stream", False)

        prompt = f"{voice}: {text}"
        logger.info(f"Formatted prompt with voice {voice}")

        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "stream": stream,
            "health_check": False
        }

    def inference(self, inputs):
        """Run model inference (non-streaming)."""
        if inputs.get("health_check", False):
            return {"status": "ok"}

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        logger.info(f"Running non-streaming inference with max_new_tokens={inputs['max_new_tokens']}")
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=inputs['max_new_tokens'],
                do_sample=True,
                temperature=inputs['temperature'],
                top_p=inputs['top_p'],
                repetition_penalty=inputs['repetition_penalty'],
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )
        logger.info(f"Generation complete, output shape: {generated_ids.shape}")
        return generated_ids

    def postprocess(self, generated_ids):
        """Process generated tokens into a single audio file (non-streaming)."""
        if isinstance(generated_ids, dict) and "status" in generated_ids:
            return generated_ids

        logger.info("Postprocessing generated tokens for non-streaming output")

        # Keep only the tokens after the last start-of-audio marker.
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
        else:
            cropped_tensor = generated_ids
            logger.warning("No start audio token found in non-streaming output")

        # Drop the end-of-audio marker and decode the remaining codes to WAV.
        code_list = [t.item() for t in cropped_tensor.squeeze() if t.item() != self.end_audio_token]

        audio_b64 = self._decode_audio_chunk(code_list)
        if not audio_b64:
            return {"error": "No audio samples generated"}

        logger.info(f"Audio encoded as base64, length: {len(audio_b64)}")
        return {
            "audio_b64": audio_b64,
            "sample_rate": 24000
        }

    def _decode_audio_chunk(self, code_list):
        """Decodes a list of token codes into a base64 WAV string."""
        if not code_list:
            return None

        # SNAC decoding expects complete frames of 7 codes; trim any partial frame.
        new_length = (len(code_list) // 7) * 7
        if new_length == 0:
            return None

        trimmed_list = code_list[:new_length]

        # Shift token IDs down into the SNAC codebook index range.
        adjusted_codes = [t - 128266 for t in trimmed_list]

        # Decode to a waveform with SNAC.
        audio = self.redistribute_codes(adjusted_codes)
        audio_sample = audio.detach().squeeze().cpu().numpy()

        # Convert float samples to 16-bit PCM.
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        # Write a mono 24 kHz WAV into memory and base64-encode it.
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(24000)
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        return base64.b64encode(wav_data).decode('utf-8')

    def _stream_inference(self, inputs):
        """Generator function for streaming inference."""
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            streamer=streamer,
            max_new_tokens=inputs['max_new_tokens'],
            do_sample=True,
            temperature=inputs['temperature'],
            top_p=inputs['top_p'],
            repetition_penalty=inputs['repetition_penalty'],
            eos_token_id=self.end_audio_token,
        )

        # Run generation in a background thread so tokens can be consumed as they arrive.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        logger.info("Starting streaming inference...")
        token_buffer = []
        found_start_audio = False

        for new_text in streamer:
            # Re-encode the streamed text to recover the underlying token IDs.
            new_tokens = self.tokenizer.encode(new_text, add_special_tokens=False)

            for token in new_tokens:
                # Ignore everything until the start-of-audio marker appears.
                if not found_start_audio:
                    if token == self.start_audio_token:
                        found_start_audio = True
                    continue

                if token == self.end_audio_token:
                    # Flush any complete frames still buffered, then stop the stream.
                    if len(token_buffer) >= 7:
                        audio_b64 = self._decode_audio_chunk(token_buffer)
                        if audio_b64:
                            yield f"data: {json.dumps({'audio_b64': audio_b64})}\n\n"
                    logger.info("End of audio token found. Stopping stream.")
                    return

                token_buffer.append(token)

                # Emit audio as soon as at least one complete 7-code frame is buffered.
                if len(token_buffer) >= 7:
                    process_len = (len(token_buffer) // 7) * 7
                    codes_to_process = token_buffer[:process_len]
                    token_buffer = token_buffer[process_len:]

                    audio_b64 = self._decode_audio_chunk(codes_to_process)
                    if audio_b64:
                        logger.info(f"Yielding audio chunk of {len(codes_to_process)} tokens")
                        yield f"data: {json.dumps({'audio_b64': audio_b64})}\n\n"

        # Flush any remaining buffered codes once the streamer is exhausted.
        if token_buffer:
            audio_b64 = self._decode_audio_chunk(token_buffer)
            if audio_b64:
                yield f"data: {json.dumps({'audio_b64': audio_b64})}\n\n"

        logger.info("Streaming complete.")

    def redistribute_codes(self, code_list):
        """Reorganize tokens for SNAC decoding."""
        # Each frame of 7 codes is split across the three SNAC codebook layers
        # (1 + 2 + 4 codes per frame); each position carries an offset of n * 4096.
        layer_1, layer_2, layer_3 = [], [], []
        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]
        return self.snac_model.decode(codes)

    def __call__(self, data):
        """Main entry point for the handler."""
        if not self._warmed_up:
            self._warmup()

        try:
            if data == "ping" or (isinstance(data, dict) and data.get("inputs") == "ping"):
                logger.info("Processing health check request")
                return {"status": "ok"}

            preprocessed_inputs = self.preprocess(data)

            if preprocessed_inputs.get("stream"):
                return self._stream_inference(preprocessed_inputs)
            else:
                model_outputs = self.inference(preprocessed_inputs)
                return self.postprocess(model_outputs)

        except Exception as e:
            import traceback
            logger.error(f"Error processing request: {str(e)}\n{traceback.format_exc()}")
            return {"error": str(e), "traceback": traceback.format_exc()}

    def _warmup(self):
        try:
            logger.info("Warming up model...")
            dummy_prompt = "tara: Hello"
            input_ids = self.tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(self.device)
            _ = self.model.generate(input_ids=input_ids, max_new_tokens=10)
            self._warmed_up = True
            logger.info("Warmup complete.")
        except Exception as e:
            logger.error(f"[WARMUP ERROR] {str(e)}")