atharva27 committed
Commit 3a8dbfa · verified · 1 Parent(s): 9cc3956

Update handler.py

Files changed (1)
  1. handler.py +125 -36
handler.py CHANGED
@@ -1,40 +1,129 @@
- # handler.py
  import os
- import base64
+ import re
+ import time
+ import asyncio
+ import numpy as np
  import torch
- from orpheus_tts import OrpheusModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from snac import SNAC
+ from livekit import rtc, api

- # This is called once at container startup
- def init():
-     global model, device
-     # pick CUDA if available, else CPU
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model = OrpheusModel(
-         model_name="canopylabs/orpheus-3b-0.1-ft",
-         dtype=torch.bfloat16
-     ).to(device)
-
- # This is called on each HTTP request
- def inference(request):
-     """
-     Expects JSON: { "text": "Hello world!" }
-     Returns JSON: { "audio_base64": "<base64-wav-bytes>" }
-     """
-     payload = request.json()
-     text = payload.get("text", "")
-     # generate raw PCM chunks
-     pcm_chunks = model.generate_speech(prompt=text)
-     pcm_bytes = b"".join(pcm_chunks)
-
-     # wrap in a 24 kHz 16-bit WAV header
-     import io, wave
-     buf = io.BytesIO()
-     with wave.open(buf, "wb") as wf:
-         wf.setnchannels(1)
-         wf.setsampwidth(2)
-         wf.setframerate(24000)
-         wf.writeframes(pcm_bytes)
-     wav = buf.getvalue()
-     b64 = base64.b64encode(wav).decode("utf-8")
-     return { "audio_base64": b64 }
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+         )
+         self.model.to(self.device)
+         self.model.eval()
+         # Load the SNAC audio codec model for decoding audio tokens (24 kHz speech model).
+         self.audio_codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(self.device)
+         self.audio_codec.eval()
+         # Store LiveKit credentials from environment (if provided).
+         self.livekit_url = os.getenv("LIVEKIT_URL")
+         self.livekit_api_key = os.getenv("LIVEKIT_API_KEY")
+         self.livekit_api_secret = os.getenv("LIVEKIT_API_SECRET")
+         self.livekit_room = os.getenv("LIVEKIT_ROOM") # default room name (optional)

+     def __call__(self, data: dict) -> list:
+         # Extract input text and optional voice and LiveKit parameters.
+         text_input = data.get("inputs") or data.get("text") or ""
+         if not isinstance(text_input, str) or text_input.strip() == "":
+             raise ValueError("No text input provided for TTS")
+         voice = data.get("voice", "tara") # default voice (e.g., "tara")
+         # Format prompt with voice name (Orpheus expects prompts like "voice: text").
+         prompt = f"{voice}: {text_input}"
+
+         # Encode prompt and generate output tokens with the TTS model.
+         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
+         generate_kwargs = {
+             "max_new_tokens": 1024, # allow sufficient tokens for audio output
+             "do_sample": True,
+             "temperature": 0.8,
+             "top_p": 0.95,
+             "repetition_penalty": 1.1, # >=1.1 for stable speech generation
+             "pad_token_id": self.tokenizer.eos_token_id,
+         }
+         output_ids = self.model.generate(input_ids, **generate_kwargs)
+         # The generated sequence includes the prompt; isolate newly generated tokens:
+         generated_tokens = output_ids[0, input_ids.size(1):]
+         output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
+
+         # Extract audio token IDs (SNAC codec tokens) from the model output.
+         audio_token_ids = [int(m) for m in re.findall(r"<custom_token_(\d+)>", output_text)]
+         if not audio_token_ids:
+             return [{"error": "TTS model produced no audio tokens"}]
+
+         # Convert list of token IDs into SNAC codec input tensors (7 tokens per audio frame).
+         # If the number of tokens is not a multiple of 7, pad with zeros (silence) to complete the last frame.
+         if len(audio_token_ids) % 7 != 0:
+             pad_len = 7 - (len(audio_token_ids) % 7)
+             audio_token_ids.extend([0] * pad_len)
+         audio_ids = torch.tensor(audio_token_ids, dtype=torch.int32, device=self.device).reshape(-1, 7)
+         # Separate hierarchical codec codebooks: coarse (level 0), mid (level 1), fine (level 2).
+         codes_0 = audio_ids[:, 0].unsqueeze(0) # shape (1, N_frames)
+         codes_1 = torch.stack((audio_ids[:, 1], audio_ids[:, 4]), dim=1).flatten().unsqueeze(0)
+         codes_2 = torch.stack((audio_ids[:, 2], audio_ids[:, 3], audio_ids[:, 5], audio_ids[:, 6]), dim=1).flatten().unsqueeze(0)
+
+         # Decode audio tokens to waveform audio using SNAC codec model.
+         with torch.inference_mode():
+             audio_wave = self.audio_codec.decode([codes_0, codes_1, codes_2])
+         audio_wave = audio_wave.squeeze().cpu().numpy() # shape: (num_samples,)
+         # Convert waveform from float (-1.0 to 1.0) to 16-bit PCM samples.
+         audio_pcm = (audio_wave * 32767.0).astype(np.int16)
+         sample_rate = 24000 # Hz (SNAC 24 kHz model output)
+         num_channels = 1
+
+         # Determine LiveKit connection info (from request or env).
+         lk_url = data.get("livekit_url", self.livekit_url)
+         lk_token = data.get("livekit_token", None)
+         room_name = data.get("livekit_room", self.livekit_room)
+         identity = data.get("livekit_identity", f"tts-agent-{int(time.time())}")
+         participant_name = data.get("livekit_name", "TTS Agent")
+         if not lk_token:
+             # If no direct token is provided, generate one using API key/secret.
+             if not (lk_url and self.livekit_api_key and self.livekit_api_secret and room_name):
+                 return [{"error": "LiveKit connection information missing"}]
+             token_builder = api.AccessToken(self.livekit_api_key, self.livekit_api_secret)
+             token_builder.with_identity(identity).with_name(participant_name)
+             token_builder.with_grants(api.VideoGrants(room_join=True, room=room_name))
+             lk_token = token_builder.to_jwt()
+
+         # Asynchronous function to connect to LiveKit and stream audio frames.
+         async def stream_audio():
+             room = rtc.Room()
+             try:
+                 await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
+             except Exception as e:
+                 return f"Failed to connect to LiveKit: {e}"
+             # Create an audio track for streaming the TTS output.
+             source = rtc.AudioSource(sample_rate, num_channels)
+             track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
+             await room.local_participant.publish_track(track, rtc.TrackPublishOptions(name="TTS Audio"))
+             # Stream the audio in chunks for real-time playback.
+             frame_duration = 0.05 # 50 ms per frame
+             frame_samples = int(sample_rate * frame_duration)
+             total_samples = len(audio_pcm)
+             for start in range(0, total_samples, frame_samples):
+                 end = min(start + frame_samples, total_samples)
+                 chunk = audio_pcm[start:end]
+                 # Create an AudioFrame and copy the PCM chunk into it.
+                 frame = rtc.AudioFrame.create(sample_rate=sample_rate, num_channels=num_channels,
+                                               samples_per_channel=len(chunk) // num_channels)
+                 frame_buffer = np.frombuffer(frame.data, dtype=np.int16)
+                 np.copyto(frame_buffer[:len(chunk)], chunk)
+                 await source.capture_frame(frame)
+                 # Sleep to maintain real-time pace (synchronize with frame duration).
+                 await asyncio.sleep(frame_duration)
+             # Disconnect from the room after streaming is finished.
+             await room.disconnect()
+             return None
+
+         # Run the streaming coroutine and wait for completion.
+         error = asyncio.run(stream_audio())
+         if error:
+             return [{"error": error}]
+         # Return a success status (audio is delivered via LiveKit, not in the HTTP response).
+         return [{"status": "success"}]
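
For quick local verification, a minimal smoke-test sketch for the new handler follows. It is not part of the commit: it assumes the file is importable as `handler`, that real LiveKit credentials are available, and every concrete value below (model path, URL, room) is a placeholder rather than something taken from this repository.

# Hypothetical smoke test (sketch, not part of the commit). The payload keys
# ("inputs", "voice", "livekit_url", "livekit_room") are the ones that
# EndpointHandler.__call__ reads above; all concrete values are placeholders.
import os
from handler import EndpointHandler

os.environ.setdefault("LIVEKIT_API_KEY", "<api-key>")        # placeholder credential
os.environ.setdefault("LIVEKIT_API_SECRET", "<api-secret>")  # placeholder credential

handler = EndpointHandler(path="canopylabs/orpheus-3b-0.1-ft")  # assumed model path
result = handler({
    "inputs": "Hello from Orpheus!",
    "voice": "tara",
    "livekit_url": "wss://example.livekit.cloud",  # placeholder URL
    "livekit_room": "tts-demo",                    # placeholder room name
})
print(result)  # [{"status": "success"}] if streaming succeeded, else [{"error": ...}]

Note that the environment variables must be set before the handler is constructed, since `__init__` reads them once at load time.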