File size: 4,983 Bytes
8362005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
import base64
import re
import tempfile
from typing import AsyncGenerator, Optional, List

import aiohttp
import numpy as np
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from loguru import logger
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language


class ChatterboxTTSService(TTSService):
    """Text-to-Speech service using Chatterbox for on-device TTS.

    The service loads a ``ChatterboxTTS`` model once at construction time and
    synthesizes each ``run_tts`` request in a worker thread so the event loop
    is not blocked. An optional audio prompt (local path or HTTP(S) URL) can
    be supplied for voice cloning; URL prompts are downloaded to a temporary
    file that is removed after each synthesis.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Chatterbox TTS service."""

        audio_prompt: Optional[str] = Field(
            None, description="URL or file path to an audio prompt for voice cloning."
        )
        exaggeration: float = Field(0.5, ge=0.0, le=1.0)
        cfg: float = Field(0.5, ge=0.0, le=1.0)
        temperature: float = Field(0.8, ge=0.0, le=1.0)

    def __init__(
        self,
        *,
        device: str = "cpu",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Chatterbox TTS service.

        Args:
            device: The device to run the model on (e.g., "cpu", "cuda").
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Forwarded unchanged to ``TTSService.__init__``.
        """
        super().__init__(**kwargs)
        # Set first so that __del__ / cleanup is safe even if model loading fails.
        self._temp_files: List[str] = []

        # Build the default per instance instead of sharing one object that
        # was created at function-definition time.
        if params is None:
            params = self.InputParams()

        logger.info(f"Initializing Chatterbox TTS service on device: {device}")
        self._model = ChatterboxTTS.from_pretrained(device=device)
        self._sample_rate = self._model.sr
        # pydantic v2 deprecates .dict(); prefer model_dump() when it exists.
        dump = getattr(params, "model_dump", None) or params.dict
        self._settings = dump()
        logger.info("Chatterbox TTS service initialized")

    def __del__(self):
        """Best-effort removal of any temporary prompt files still on disk."""
        try:
            self._cleanup_temp_files()
        except Exception:
            # Never let a finalizer raise (e.g. during interpreter shutdown).
            pass

    def can_generate_metrics(self) -> bool:
        """Report that this service emits metrics."""
        return True

    def language_to_service_language(self, language: Language) -> Optional[str]:
        """Returns the language code for Chatterbox TTS. Only English is supported.

        Any non-English language logs a warning and still maps to ``"en"``.
        """
        if not language.value.startswith("en"):
            logger.warning(
                f"Chatterbox TTS only supports English, but got {language}. Defaulting to English."
            )
        return "en"

    async def _handle_audio_prompt(self, audio_prompt: str) -> Optional[str]:
        """Resolve an audio prompt to a local file path.

        Args:
            audio_prompt: A local file path or an ``http(s)://`` URL.

        Returns:
            The path unchanged when it is not a URL; the path of a downloaded
            temporary ``.wav`` file when it is a URL; ``None`` if the download
            fails (the error is logged and synthesis continues without a prompt).
        """
        if not re.match(r"^https?://", audio_prompt):
            return audio_prompt

        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(audio_prompt) as resp:
                    resp.raise_for_status()
                    content = await resp.read()

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Register before writing so a failed write cannot leak the file.
                self._temp_files.append(tmp_file.name)
                tmp_file.write(content)
            return tmp_file.name
        except Exception as e:
            logger.error(f"Error downloading audio prompt from URL: {e}")
            return None

    def _cleanup_temp_files(self):
        """Delete every temporary file recorded in ``self._temp_files``."""
        import os

        # The attribute may be missing if __init__ failed before setting it.
        temp_files = getattr(self, "_temp_files", None)
        if not temp_files:
            return

        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as e:
                logger.warning(f"Error cleaning up temp file {temp_file}: {e}")
        temp_files.clear()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech from text using Chatterbox.

        Args:
            text: The text to synthesize.

        Yields:
            ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` holding the whole
            utterance as 16-bit mono PCM, then ``TTSStoppedFrame``. If anything
            fails, an ``ErrorFrame`` is yielded instead of the remaining frames.
        """
        logger.debug(f"Generating TTS for: [{text}]")

        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            audio_prompt_path = self._settings.get("audio_prompt")
            if audio_prompt_path:
                audio_prompt_path = await self._handle_audio_prompt(audio_prompt_path)

            await self.start_tts_usage_metrics(text)

            settings = self._settings

            def _generate():
                # Keyword arguments keep the call correct regardless of the
                # positional order of ChatterboxTTS.generate's parameters.
                return self._model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=settings["exaggeration"],
                    cfg_weight=settings["cfg"],
                    temperature=settings["temperature"],
                )

            # Synthesis is blocking, so run it in the default thread-pool executor.
            loop = asyncio.get_running_loop()
            wav = await loop.run_in_executor(None, _generate)

            # The full utterance is produced in one shot, so the first byte is
            # available as soon as generation returns.
            await self.stop_ttfb_metrics()

            # Clip before scaling: samples outside [-1, 1] would otherwise wrap
            # around when cast to int16.
            samples = np.clip(wav.detach().cpu().numpy(), -1.0, 1.0)
            audio_data = (samples * 32767).astype(np.int16).tobytes()

            # Label the audio with the model's own rate.
            # NOTE(review): self._sample_rate is shared with the base class and
            # may be reassigned there — confirm against the pipecat version in use.
            yield TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self._model.sr,
                num_channels=1,
            )

            yield TTSStoppedFrame()
        except Exception as e:
            # loguru attaches the traceback via .exception(); it has no
            # exc_info= keyword (extra kwargs are used to format the message).
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {e}")
        finally:
            self._cleanup_temp_files()