File size: 3,782 Bytes
8362005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncio
from queue import Queue
from threading import Event, Thread
from typing import AsyncGenerator, List, Optional

from loguru import logger
from orpheus_tts import OrpheusModel
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService


class OrpheusTTSService(TTSService):
    """TTS service for Orpheus.

    Speech is produced by ``OrpheusModel.generate_speech`` on a background
    thread and streamed to the pipeline as ``TTSAudioRawFrame`` chunks.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Orpheus TTS service.

        Every field is forwarded unchanged, as a keyword argument of the same
        name, to ``OrpheusModel.generate_speech``.
        """

        voice: str = Field("tara", description="Voice to use for generation.")
        repetition_penalty: Optional[float] = Field(1.1)
        stop_token_ids: Optional[List[int]] = Field([128258])
        max_tokens: Optional[int] = Field(2000)
        temperature: Optional[float] = Field(0.4)
        top_p: Optional[float] = Field(0.9)

    def __init__(
        self,
        *,
        model_name: str = "canopylabs/orpheus-tts-0.1-finetune-prod",
        sample_rate: int = 24000,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Orpheus TTS service.

        Args:
            model_name: The name of the Orpheus model to use.
            sample_rate: The sample rate of the audio.
            params: Configuration parameters for TTS generation. When omitted,
                a fresh ``InputParams()`` with default values is used.
            **kwargs: Extra keyword arguments passed to ``TTSService``.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)
        # Build the default per instance instead of in the signature, where it
        # would be created once at definition time and shared by every service.
        if params is None:
            params = self.InputParams()
        logger.info(f"Initializing Orpheus TTS service with model: {model_name}")
        self._model = OrpheusModel(model_name=model_name)
        self._settings = params.dict()
        logger.info("Orpheus TTS service initialized")

    def can_generate_metrics(self) -> bool:
        """Report that this service can generate metrics."""
        return True

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech for ``text`` and stream it as frames.

        The blocking ``generate_speech`` iterator runs on a worker thread that
        hands chunks over through a queue; this coroutine pulls from the queue
        via the default executor so the event loop is not blocked.

        Args:
            text: The text to synthesize.

        Yields:
            ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` per audio chunk,
            then ``TTSStoppedFrame``. If generation fails, an ``ErrorFrame``
            is yielded instead of ``TTSStoppedFrame``.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        # Set when this generator finishes for any reason (including being
        # closed or cancelled early) so the worker stops producing audio.
        cancelled = Event()
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            loop = asyncio.get_running_loop()
            chunks = Queue()

            def generate() -> None:
                """Run the blocking Orpheus stream and enqueue its output."""
                try:
                    stream = self._model.generate_speech(
                        prompt=text,
                        voice=self._settings["voice"],
                        repetition_penalty=self._settings["repetition_penalty"],
                        stop_token_ids=self._settings["stop_token_ids"],
                        max_tokens=self._settings["max_tokens"],
                        temperature=self._settings["temperature"],
                        top_p=self._settings["top_p"],
                    )
                    for chunk in stream:
                        if cancelled.is_set():
                            # Nobody is consuming any more; stop queueing.
                            break
                        chunks.put(chunk)
                except Exception as exc:
                    # loguru has no ``exc_info`` kwarg; ``exception()`` attaches
                    # the traceback, and passing ``exc`` as a format argument
                    # keeps braces in its text from being parsed as fields.
                    logger.exception(
                        "Error in Orpheus generate_speech thread: {}", exc
                    )
                    chunks.put(exc)
                finally:
                    chunks.put(None)  # Sentinel to indicate end of stream

            worker = Thread(target=generate)
            worker.start()

            await self.start_tts_usage_metrics(text)

            while True:
                item = await loop.run_in_executor(None, chunks.get)
                if isinstance(item, Exception):
                    raise item
                if item is None:
                    break

                # Audio is available: close the TTFB measurement opened above.
                await self.stop_ttfb_metrics()
                yield TTSAudioRawFrame(
                    audio=item, sample_rate=self.sample_rate, num_channels=1
                )

            # The sentinel is the worker's last action, so this returns
            # almost immediately.
            worker.join()

            yield TTSStoppedFrame()
        except Exception as e:
            logger.exception("{} exception: {}", self, e)
            yield ErrorFrame(f"Error generating audio: {str(e)}")
        finally:
            cancelled.set()