Spaces:
Running
Running
File size: 4,983 Bytes
8362005 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import asyncio
import base64
import re
import tempfile
from typing import AsyncGenerator, Optional, List
import aiohttp
import numpy as np
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from loguru import logger
from pydantic import BaseModel, Field
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
class ChatterboxTTSService(TTSService):
    """Text-to-Speech service using Chatterbox for on-device TTS.

    This service uses Chatterbox to generate speech. It supports voice cloning
    from an audio prompt, given either as a local file path or as an
    ``http://``/``https://`` URL that is downloaded to a temporary file for the
    duration of each ``run_tts`` call.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Chatterbox TTS service."""

        audio_prompt: Optional[str] = Field(
            None, description="URL or file path to an audio prompt for voice cloning."
        )
        exaggeration: float = Field(0.5, ge=0.0, le=1.0)
        # Passed to ChatterboxTTS.generate as ``cfg_weight``.
        cfg: float = Field(0.5, ge=0.0, le=1.0)
        temperature: float = Field(0.8, ge=0.0, le=1.0)

    def __init__(
        self,
        *,
        device: str = "cpu",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Chatterbox TTS service.

        Args:
            device: The device to run the model on (e.g., "cpu", "cuda").
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        super().__init__(**kwargs)
        # Set before the (slow, fallible) model load so __del__ can always run.
        self._temp_files: List[str] = []
        if params is None:
            params = self.InputParams()
        logger.info(f"Initializing Chatterbox TTS service on device: {device}")
        self._model = ChatterboxTTS.from_pretrained(device=device)
        # Rate of the audio the model actually produces. Stored separately
        # because pipecat's TTSService may reassign ``_sample_rate`` when the
        # pipeline starts (NOTE(review): behaviour depends on pipecat version).
        self._model_sample_rate = self._model.sr
        self._sample_rate = self._model_sample_rate
        # pydantic v2 renamed .dict() to .model_dump(); support both majors.
        dump = getattr(params, "model_dump", None) or params.dict
        self._settings = dump()
        logger.info("Chatterbox TTS service initialized")

    def __del__(self):
        # __init__ may have raised before _temp_files existed; nothing to clean then.
        if getattr(self, "_temp_files", None):
            self._cleanup_temp_files()

    def can_generate_metrics(self) -> bool:
        """Report that this service emits pipecat metrics (TTFB, usage)."""
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Returns the language code for Chatterbox TTS. Only English is supported."""
        if language.value.startswith("en"):
            return "en"
        logger.warning(
            f"Chatterbox TTS only supports English, but got {language}. Defaulting to English."
        )
        return "en"

    async def _handle_audio_prompt(self, audio_prompt: str) -> Optional[str]:
        """Resolve an audio prompt setting to a local file path.

        URLs starting with ``http://`` or ``https://`` are downloaded into a
        temporary ``.wav`` file that is registered for cleanup; any other value
        is returned unchanged and treated as a local path.

        Returns:
            The local path to use, or ``None`` if downloading failed (the error
            is logged and generation proceeds without a prompt).
        """
        if not audio_prompt.startswith(("http://", "https://")):
            return audio_prompt
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(audio_prompt) as resp:
                    resp.raise_for_status()
                    content = await resp.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Register first so a failed write still gets cleaned up.
                self._temp_files.append(tmp_file.name)
                tmp_file.write(content)
            return tmp_file.name
        except Exception as e:
            # Best effort: a bad prompt URL should not abort synthesis.
            logger.error(f"Error downloading audio prompt from URL: {e}")
            return None

    def _cleanup_temp_files(self):
        """Delete every registered temporary file, logging (not raising) on failure."""
        import os

        for temp_file in self._temp_files:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except OSError as e:
                logger.warning(f"Error cleaning up temp file {temp_file}: {e}")
        self._temp_files.clear()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Chatterbox.

        Yields ``TTSStartedFrame``, then a single ``TTSAudioRawFrame`` holding
        the whole utterance as 16-bit mono PCM at the model's sample rate, then
        ``TTSStoppedFrame``. If anything fails, an ``ErrorFrame`` is yielded in
        place of the remaining frames. Temporary prompt files are removed
        before the generator finishes.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            audio_prompt_path = self._settings.get("audio_prompt")
            if audio_prompt_path:
                audio_prompt_path = await self._handle_audio_prompt(audio_prompt_path)

            await self.start_tts_usage_metrics(text)

            def _synthesize():
                # Keyword arguments on purpose: the parameter order of
                # ChatterboxTTS.generate has changed between releases, so
                # positional values could bind to the wrong parameters.
                return self._model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=self._settings["exaggeration"],
                    cfg_weight=self._settings["cfg"],
                    temperature=self._settings["temperature"],
                )

            # Model inference is blocking; keep it off the event loop.
            loop = asyncio.get_running_loop()
            wav = await loop.run_in_executor(None, _synthesize)
            # All audio is available at this point: that is the first byte.
            await self.stop_ttfb_metrics()

            # Clip before scaling so out-of-range samples saturate instead of
            # wrapping around when cast to int16.
            samples = np.clip(wav.cpu().numpy(), -1.0, 1.0)
            audio_data = (samples * 32767).astype(np.int16).tobytes()
            yield TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self._model_sample_rate,
                num_channels=1,
            )
            yield TTSStoppedFrame()
        except Exception as e:
            logger.error(f"{self} exception: {e}", exc_info=True)
            yield ErrorFrame(f"Error generating audio: {e}")
        finally:
            self._cleanup_temp_files()
|