Update handler.py
handler.py CHANGED  (+125 -36)
@@ -1,40 +1,129 @@
-# handler.py
 import os
-import
+import re
+import time
+import asyncio
+import numpy as np
 import torch
-from
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from snac import SNAC
+from livekit import rtc, api
 
-
-def
-
-
-
-
-
-
-
-
-#
-
-
-
-
-
-
-
-    # generate raw PCM chunks
-    pcm_chunks = model.generate_speech(prompt=text)
-    pcm_bytes = b"".join(pcm_chunks)
-
-    # wrap in a 24 kHz 16-bit WAV header
-    import io, wave
-    buf = io.BytesIO()
-    with wave.open(buf, "wb") as wf:
-        wf.setnchannels(1)
-        wf.setsampwidth(2)
-        wf.setframerate(24000)
-        wf.writeframes(pcm_bytes)
-    wav = buf.getvalue()
-    b64 = base64.b64encode(wav).decode("utf-8")
-    return { "audio_base64": b64 }
+class EndpointHandler:
+    def __init__(self, path: str = ""):
+        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+        )
+        self.model.to(self.device)
+        self.model.eval()
+        # Load the SNAC audio codec model for decoding audio tokens (24 kHz speech model).
+        self.audio_codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(self.device)
+        self.audio_codec.eval()
+        # Store LiveKit credentials from environment (if provided).
+        self.livekit_url = os.getenv("LIVEKIT_URL")
+        self.livekit_api_key = os.getenv("LIVEKIT_API_KEY")
+        self.livekit_api_secret = os.getenv("LIVEKIT_API_SECRET")
+        self.livekit_room = os.getenv("LIVEKIT_ROOM")  # default room name (optional)
 
+    def __call__(self, data: dict) -> list:
+        # Extract input text and optional voice and LiveKit parameters.
+        text_input = data.get("inputs") or data.get("text") or ""
+        if not isinstance(text_input, str) or text_input.strip() == "":
+            raise ValueError("No text input provided for TTS")
+        voice = data.get("voice", "tara")  # default voice (e.g., "tara")
+        # Format prompt with voice name (Orpheus expects prompts like "voice: text").
+        prompt = f"{voice}: {text_input}"
+
+        # Encode prompt and generate output tokens with the TTS model.
+        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
+        generate_kwargs = {
+            "max_new_tokens": 1024,  # allow sufficient tokens for audio output
+            "do_sample": True,
+            "temperature": 0.8,
+            "top_p": 0.95,
+            "repetition_penalty": 1.1,  # >= 1.1 for stable speech generation
+            "pad_token_id": self.tokenizer.eos_token_id,
+        }
+        output_ids = self.model.generate(input_ids, **generate_kwargs)
+        # The generated sequence includes the prompt; isolate newly generated tokens.
+        generated_tokens = output_ids[0, input_ids.size(1):]
+        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
+
+        # Extract audio token IDs (SNAC codec tokens) from the model output.
+        audio_token_ids = [int(m) for m in re.findall(r"<custom_token_(\d+)>", output_text)]
+        if not audio_token_ids:
+            return [{"error": "TTS model produced no audio tokens"}]
+
+        # Convert list of token IDs into SNAC codec input tensors (7 tokens per audio frame).
+        # If the number of tokens is not a multiple of 7, pad with zeros (silence) to complete the last frame.
+        if len(audio_token_ids) % 7 != 0:
+            pad_len = 7 - (len(audio_token_ids) % 7)
+            audio_token_ids.extend([0] * pad_len)
+        audio_ids = torch.tensor(audio_token_ids, dtype=torch.int32, device=self.device).reshape(-1, 7)
+        # Separate hierarchical codec codebooks: coarse (level 0), mid (level 1), fine (level 2).
+        codes_0 = audio_ids[:, 0].unsqueeze(0)  # shape (1, N_frames)
+        codes_1 = torch.stack((audio_ids[:, 1], audio_ids[:, 4]), dim=1).flatten().unsqueeze(0)
+        codes_2 = torch.stack((audio_ids[:, 2], audio_ids[:, 3], audio_ids[:, 5], audio_ids[:, 6]), dim=1).flatten().unsqueeze(0)
+
+        # Decode audio tokens to waveform audio using the SNAC codec model.
+        with torch.inference_mode():
+            audio_wave = self.audio_codec.decode([codes_0, codes_1, codes_2])
+        audio_wave = audio_wave.squeeze().cpu().numpy()  # shape: (num_samples,)
+        # Convert waveform from float (-1.0 to 1.0) to 16-bit PCM samples.
+        audio_pcm = (audio_wave * 32767.0).astype(np.int16)
+        sample_rate = 24000  # Hz (SNAC 24 kHz model output)
+        num_channels = 1
+
+        # Determine LiveKit connection info (from request or env).
+        lk_url = data.get("livekit_url", self.livekit_url)
+        lk_token = data.get("livekit_token", None)
+        room_name = data.get("livekit_room", self.livekit_room)
+        identity = data.get("livekit_identity", f"tts-agent-{int(time.time())}")
+        participant_name = data.get("livekit_name", "TTS Agent")
+        if not lk_token:
+            # If no direct token is provided, generate one using the API key/secret.
+            if not (lk_url and self.livekit_api_key and self.livekit_api_secret and room_name):
+                return [{"error": "LiveKit connection information missing"}]
+            token_builder = api.AccessToken(self.livekit_api_key, self.livekit_api_secret)
+            token_builder.with_identity(identity).with_name(participant_name)
+            token_builder.with_grants(api.VideoGrants(room_join=True, room=room_name))
+            lk_token = token_builder.to_jwt()
+
+        # Asynchronous function to connect to LiveKit and stream audio frames.
+        async def stream_audio():
+            room = rtc.Room()
+            try:
+                await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
+            except Exception as e:
+                return f"Failed to connect to LiveKit: {e}"
+            # Create an audio track for streaming the TTS output.
+            source = rtc.AudioSource(sample_rate, num_channels)
+            track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
+            await room.local_participant.publish_track(track, rtc.TrackPublishOptions(name="TTS Audio"))
+            # Stream the audio in chunks for real-time playback.
+            frame_duration = 0.05  # 50 ms per frame
+            frame_samples = int(sample_rate * frame_duration)
+            total_samples = len(audio_pcm)
+            for start in range(0, total_samples, frame_samples):
+                end = min(start + frame_samples, total_samples)
+                chunk = audio_pcm[start:end]
+                # Create an AudioFrame and copy the PCM chunk into it.
+                frame = rtc.AudioFrame.create(sample_rate=sample_rate, num_channels=num_channels,
+                                              samples_per_channel=len(chunk) // num_channels)
+                frame_buffer = np.frombuffer(frame.data, dtype=np.int16)
+                np.copyto(frame_buffer[:len(chunk)], chunk)
+                await source.capture_frame(frame)
+                # Sleep to maintain real-time pace (synchronized with the frame duration).
+                await asyncio.sleep(frame_duration)
+            # Disconnect from the room after streaming is finished.
+            await room.disconnect()
+            return None
+
+        # Run the streaming coroutine and wait for completion.
+        error = asyncio.run(stream_audio())
+        if error:
+            return [{"error": error}]
+        # Return a success status (audio is delivered via LiveKit, not in the HTTP response).
+        return [{"status": "success"}]
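A minimal local smoke test of the new handler could look like the sketch below. It only uses the interface defined in this commit; the `path="."` argument, the room name, and the identity are illustrative assumptions, and the LIVEKIT_URL / LIVEKIT_API_KEY / LIVEKIT_API_SECRET (or a pre-made token) must already point at a reachable LiveKit deployment for the streaming step to succeed.

# sketch only: assumes handler.py and the Orpheus weights are in the current directory
# and that the LIVEKIT_* environment variables are set.
from handler import EndpointHandler

handler = EndpointHandler(path=".")          # path would normally be the Hub repo checkout
result = handler({
    "inputs": "Hello from the Orpheus endpoint.",
    "voice": "tara",                         # optional; defaults to "tara"
    "livekit_room": "demo-room",             # optional override of LIVEKIT_ROOM (hypothetical name)
    "livekit_identity": "tts-agent-demo",    # optional participant identity (hypothetical name)
})
print(result)  # e.g. [{"status": "success"}] or [{"error": "..."}]

Note that the synthesized audio is published to the LiveKit room rather than returned in the response, so verifying the output requires a participant subscribed to that room.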