klentyboopathi's picture
Initial commit
8362005
import asyncio
import base64
import re
import tempfile
from typing import AsyncGenerator, Optional, List
import aiohttp
import numpy as np
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
class ChatterboxTTSService(TTSService):
    """Text-to-Speech service using Chatterbox for on-device TTS.

    This service uses Chatterbox to generate speech. It supports voice cloning
    from an audio prompt, given either as a local file path or as an
    ``http(s)://`` URL that is downloaded to a temporary file before synthesis.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Chatterbox TTS service."""

        audio_prompt: Optional[str] = Field(
            None, description="URL or file path to an audio prompt for voice cloning."
        )
        # Forwarded to ChatterboxTTS.generate() as `exaggeration`.
        exaggeration: float = Field(0.5, ge=0.0, le=1.0)
        # Forwarded to ChatterboxTTS.generate() as `cfg_weight`.
        cfg: float = Field(0.5, ge=0.0, le=1.0)
        # Forwarded to ChatterboxTTS.generate() as `temperature`.
        temperature: float = Field(0.8, ge=0.0, le=1.0)

    def __init__(
        self,
        *,
        device: str = "cpu",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Chatterbox TTS service.

        Args:
            device: The device to run the model on (e.g., "cpu", "cuda").
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        super().__init__(**kwargs)
        # Build the default per instance instead of sharing one object that
        # was created when the function was defined.
        params = params or ChatterboxTTSService.InputParams()

        logger.info(f"Initializing Chatterbox TTS service on device: {device}")
        # Registered before the model is loaded so that cleanup (including the
        # one triggered by __del__) is safe even if loading raises.
        self._temp_files: List[str] = []
        self._model = ChatterboxTTS.from_pretrained(device=device)
        self._sample_rate = self._model.sr
        # pydantic v2 deprecates `.dict()` in favour of `.model_dump()`;
        # fall back to `.dict()` so pydantic v1 keeps working.
        if hasattr(params, "model_dump"):
            self._settings = params.model_dump()
        else:
            self._settings = params.dict()
        logger.info("Chatterbox TTS service initialized")

    def __del__(self):
        """Remove any temporary audio prompt files still registered."""
        self._cleanup_temp_files()

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Returns the language code for Chatterbox TTS. Only English is supported."""
        if language.value.startswith("en"):
            return "en"
        logger.warning(
            f"Chatterbox TTS only supports English, but got {language}. Defaulting to English."
        )
        return "en"

    async def _handle_audio_prompt(self, audio_prompt: str) -> Optional[str]:
        """Resolve an audio prompt to a local file path.

        Args:
            audio_prompt: A local file path or an ``http://``/``https://`` URL.

        Returns:
            The path unchanged when it is not a URL; otherwise the path of a
            temporary ``.wav`` file holding the downloaded bytes, or ``None``
            if the download failed (the error is logged, not raised).
        """
        if not audio_prompt.startswith(("http://", "https://")):
            return audio_prompt

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(audio_prompt) as resp:
                    resp.raise_for_status()
                    content = await resp.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Register first so the file is still removed by
                # _cleanup_temp_files() if the write below fails.
                self._temp_files.append(tmp_file.name)
                tmp_file.write(content)
            return tmp_file.name
        except Exception as e:
            # Best effort: a failed download is logged and reported as None.
            logger.error(f"Error downloading audio prompt from URL: {e}")
            return None

    def _cleanup_temp_files(self):
        """Delete every registered temporary file and empty the registry."""
        import os

        # getattr guards against __del__ running on an instance whose
        # __init__ raised before `_temp_files` was assigned.
        temp_files = getattr(self, "_temp_files", None)
        if not temp_files:
            return
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as e:
                logger.warning(f"Error cleaning up temp file {temp_file}: {e}")
        temp_files.clear()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Chatterbox.

        Args:
            text: The text to synthesize.

        Yields:
            ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` holding the whole
            utterance as mono 16-bit PCM, then ``TTSStoppedFrame``. If anything
            raises along the way, an ``ErrorFrame`` is yielded instead of the
            remaining frames.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            # NOTE: a URL prompt is downloaded again on every call because the
            # temp file is removed in the `finally` block below.
            audio_prompt_path = self._settings.get("audio_prompt")
            if audio_prompt_path:
                audio_prompt_path = await self._handle_audio_prompt(audio_prompt_path)

            await self.start_tts_usage_metrics(text)

            def _synthesize():
                # Keyword arguments keep the call correct regardless of where
                # these parameters sit in generate()'s positional order.
                return self._model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=self._settings["exaggeration"],
                    cfg_weight=self._settings["cfg"],
                    temperature=self._settings["temperature"],
                )

            # Synthesis is blocking; run it in the default executor so the
            # event loop stays responsive.
            loop = asyncio.get_running_loop()
            wav = await loop.run_in_executor(None, _synthesize)
            # The audio is produced in one piece, so the first byte is
            # available as soon as generate() returns.
            await self.stop_ttfb_metrics()

            # Clip before scaling: values outside [-1, 1] would otherwise wrap
            # around when cast to int16.
            samples = np.clip(wav.detach().cpu().numpy(), -1.0, 1.0)
            audio_data = (samples * 32767).astype(np.int16).tobytes()

            yield TTSAudioRawFrame(
                audio=audio_data,
                # Use the model's own rate; `self._sample_rate` may be
                # reassigned by the base service (TODO confirm against the
                # installed pipecat version).
                sample_rate=self._model.sr,
                num_channels=1,
            )
            yield TTSStoppedFrame()
        except Exception as e:
            # loguru has no `exc_info` kwarg; logger.exception() attaches the
            # traceback and avoids str.format() being run on the message.
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {e}")
        finally:
            self._cleanup_temp_files()