klentyboopathi's picture
Initial commit
8362005
raw
history blame
3.78 kB
import asyncio
from queue import Queue
from threading import Thread
from typing import AsyncGenerator, List, Optional
from loguru import logger
from orpheus_tts import OrpheusModel
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
class OrpheusTTSService(TTSService):
    """TTS service for Orpheus.

    This service uses Orpheus to generate speech. It streams the audio chunks:
    the blocking ``OrpheusModel.generate_speech`` iterator is drained on a
    worker thread and its chunks are handed to the async generator through a
    thread-safe queue.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Orpheus TTS service.

        Every field is forwarded unchanged as a keyword argument to
        ``OrpheusModel.generate_speech``.
        """

        voice: str = Field("tara", description="Voice to use for generation.")
        repetition_penalty: Optional[float] = Field(1.1)
        stop_token_ids: Optional[List[int]] = Field([128258])
        max_tokens: Optional[int] = Field(2000)
        temperature: Optional[float] = Field(0.4)
        top_p: Optional[float] = Field(0.9)

    def __init__(
        self,
        *,
        model_name: str = "canopylabs/orpheus-tts-0.1-finetune-prod",
        sample_rate: int = 24000,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Orpheus TTS service.

        Args:
            model_name: The name of the Orpheus model to use.
            sample_rate: The sample rate of the audio.
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Additional arguments passed to the parent ``TTSService``.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)

        # Build the default per call rather than in the signature, so that
        # instances never share one InputParams object created at import time.
        if params is None:
            params = OrpheusTTSService.InputParams()

        logger.info(f"Initializing Orpheus TTS service with model: {model_name}")
        self._model = OrpheusModel(model_name=model_name)
        self._settings = params.dict()
        logger.info("Orpheus TTS service initialized")

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech for ``text`` and stream it as audio frames.

        Args:
            text: The text to synthesize.

        Yields:
            A ``TTSStartedFrame``, then one mono ``TTSAudioRawFrame`` per chunk
            produced by the model, then a ``TTSStoppedFrame``. If generation
            fails, an ``ErrorFrame`` describing the failure is yielded instead
            of the remaining frames.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            loop = asyncio.get_running_loop()
            q = Queue()

            def generate():
                """Drain the blocking Orpheus stream into the queue."""
                try:
                    stream = self._model.generate_speech(
                        prompt=text,
                        voice=self._settings["voice"],
                        repetition_penalty=self._settings["repetition_penalty"],
                        stop_token_ids=self._settings["stop_token_ids"],
                        max_tokens=self._settings["max_tokens"],
                        temperature=self._settings["temperature"],
                        top_p=self._settings["top_p"],
                    )
                    for chunk in stream:
                        q.put(chunk)
                except Exception as e:
                    # logger.exception attaches the active traceback. Passing
                    # extra keyword arguments to a loguru call would instead
                    # trigger str.format() on the already-interpolated message.
                    logger.exception(f"Error in Orpheus generate_speech thread: {e}")
                    # Hand the failure to the consumer so it can re-raise it.
                    q.put(e)
                finally:
                    q.put(None)  # Sentinel to indicate end of stream

            thread = Thread(target=generate)
            thread.start()

            await self.start_tts_usage_metrics(text)

            ttfb_stopped = False
            while True:
                # Queue.get blocks, so wait for it on the default executor
                # instead of stalling the event loop.
                item = await loop.run_in_executor(None, q.get)
                if isinstance(item, Exception):
                    raise item
                if item is None:
                    break
                if not ttfb_stopped:
                    # First audio chunk is available: close the TTFB
                    # measurement started above.
                    await self.stop_ttfb_metrics()
                    ttfb_stopped = True
                yield TTSAudioRawFrame(
                    audio=item, sample_rate=self.sample_rate, num_channels=1
                )

            # The sentinel has been received, so the worker is past its last
            # put and this join returns promptly.
            thread.join()
            yield TTSStoppedFrame()
        except Exception as e:
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {str(e)}")