# Spaces: Running
import asyncio | |
import base64 | |
import re | |
import tempfile | |
from typing import AsyncGenerator, Optional, List | |
import aiohttp | |
import numpy as np | |
import torchaudio as ta | |
from chatterbox.tts import ChatterboxTTS | |
from loguru import logger | |
from pydantic import BaseModel, Field | |
from pipecat.frames.frames import ( | |
ErrorFrame, | |
Frame, | |
TTSAudioRawFrame, | |
TTSStartedFrame, | |
TTSStoppedFrame, | |
) | |
from pipecat.services.tts_service import TTSService | |
from pipecat.transcriptions.language import Language | |
class ChatterboxTTSService(TTSService):
    """Text-to-Speech service using Chatterbox for on-device TTS.

    This service uses Chatterbox to generate speech. It supports voice cloning
    from an audio prompt, given either as a local file path or as an
    ``http(s)`` URL that is downloaded to a temporary file before synthesis.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Chatterbox TTS service."""

        audio_prompt: Optional[str] = Field(
            None, description="URL or file path to an audio prompt for voice cloning."
        )
        # The three values below are forwarded unchanged to the model's
        # generate() call on every synthesis request.
        exaggeration: float = Field(0.5, ge=0.0, le=1.0)
        cfg: float = Field(0.5, ge=0.0, le=1.0)
        temperature: float = Field(0.8, ge=0.0, le=1.0)

    def __init__(
        self,
        *,
        device: str = "cpu",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Chatterbox TTS service.

        Args:
            device: The device to run the model on (e.g., "cpu", "cuda").
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Forwarded unchanged to the ``TTSService`` constructor.
        """
        super().__init__(**kwargs)

        # Created before the model is loaded so that __del__ stays safe even
        # if loading raises part-way through construction.
        self._temp_files: List[str] = []

        # Build the default per instance rather than once at definition time.
        if params is None:
            params = ChatterboxTTSService.InputParams()

        logger.info(f"Initializing Chatterbox TTS service on device: {device}")
        self._model = ChatterboxTTS.from_pretrained(device=device)
        # NOTE(review): the base class may also manage `_sample_rate`; confirm
        # it is not overwritten after construction.
        self._sample_rate = self._model.sr

        # Prefer the pydantic v2 API when present; fall back to v1's dict().
        dump_settings = getattr(params, "model_dump", None) or params.dict
        self._settings = dump_settings()
        logger.info("Chatterbox TTS service initialized")

    def __del__(self):
        """Remove any temporary audio-prompt files that are still registered."""
        # `_temp_files` is missing if __init__ failed early; skipping the call
        # when the list is empty also avoids importing during shutdown.
        if getattr(self, "_temp_files", None):
            self._cleanup_temp_files()

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics (always ``True``)."""
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Returns the language code for Chatterbox TTS. Only English is supported.

        Args:
            language: The requested language.

        Returns:
            Always ``"en"``; a warning is logged when a non-English language
            was requested.
        """
        if not language.value.startswith("en"):
            logger.warning(
                f"Chatterbox TTS only supports English, but got {language}. Defaulting to English."
            )
        return "en"

    async def _handle_audio_prompt(self, audio_prompt: str) -> Optional[str]:
        """Resolve an audio prompt to a local file path.

        Args:
            audio_prompt: A local file path or an ``http(s)`` URL.

        Returns:
            ``audio_prompt`` unchanged when it is not an ``http(s)`` URL, the
            path of a newly downloaded temporary ``.wav`` file when it is, or
            ``None`` if the download or the write fails (the error is logged).
        """
        if not re.match(r"^https?://", audio_prompt):
            return audio_prompt

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(audio_prompt) as resp:
                    resp.raise_for_status()
                    content = await resp.read()

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Register the path before writing so a failed write still
                # gets cleaned up instead of leaking the file.
                self._temp_files.append(tmp_file.name)
                tmp_file.write(content)
            return tmp_file.name
        except Exception as e:
            # Best effort: the caller continues without a prompt on failure.
            logger.error(f"Error downloading audio prompt from URL: {e}")
            return None

    def _cleanup_temp_files(self):
        """Delete every registered temporary file and clear the registry."""
        import os

        for temp_file in self._temp_files:
            try:
                os.unlink(temp_file)
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as e:
                logger.warning(f"Error cleaning up temp file {temp_file}: {e}")
        self._temp_files.clear()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Chatterbox.

        Args:
            text: The text to synthesize.

        Yields:
            ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` holding the whole
            utterance as 16-bit mono PCM, then ``TTSStoppedFrame``. If anything
            raises, an ``ErrorFrame`` describing the failure is yielded instead
            of the remaining frames.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            audio_prompt_path = self._settings.get("audio_prompt")
            if audio_prompt_path:
                audio_prompt_path = await self._handle_audio_prompt(audio_prompt_path)

            await self.start_tts_usage_metrics(text)

            exaggeration = self._settings["exaggeration"]
            cfg_weight = self._settings["cfg"]
            temperature = self._settings["temperature"]

            def _synthesize():
                # Keyword arguments keep the call correct even if generate()
                # gains or reorders optional parameters.
                return self._model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=exaggeration,
                    cfg_weight=cfg_weight,
                    temperature=temperature,
                )

            # Synthesis is blocking; run it in the default executor so the
            # event loop stays responsive.
            loop = asyncio.get_running_loop()
            wav = await loop.run_in_executor(None, _synthesize)

            # All audio arrives at once, so the first byte is available now.
            await self.stop_ttfb_metrics()

            # Clip before scaling: samples outside [-1, 1] would otherwise
            # wrap around when cast to int16.
            # Assumes `wav` is a float tensor in roughly [-1, 1] — TODO confirm.
            samples = np.clip(wav.cpu().numpy(), -1.0, 1.0)
            audio_data = (samples * 32767).astype(np.int16).tobytes()

            yield TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self._sample_rate,
                num_channels=1,
            )
            yield TTSStoppedFrame()
        except Exception as e:
            # logger.exception attaches the traceback; a stdlib-style
            # `exc_info=True` keyword is not understood by loguru.
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {e}")
        finally:
            self._cleanup_temp_files()