import asyncio
import functools
from typing import Any, AsyncGenerator, Optional

import numpy as np
import torch
from dia.model import Dia
from loguru import logger
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
    ErrorFrame,
    Frame,
    TTSAudioRawFrame,
    TTSStartedFrame,
    TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService


class DiaTTSService(TTSService):
    """TTS service for Dia.

    This service uses Dia to generate speech. It does not support streaming
    and will generate the entire audio at once.
    """

    class InputParams(BaseModel):
        """Configuration parameters for Dia TTS service."""

        # Forwarded to ``Dia.generate`` as the ``use_torch_compile`` keyword.
        use_torch_compile: bool = Field(False)
        # Forwarded to ``Dia.generate`` as the ``verbose`` keyword.
        verbose: bool = Field(False)

    def __init__(
        self,
        *,
        model_name: str = "nari-labs/Dia-1.6B",
        compute_dtype: str = "float32",
        device: str = "cpu",
        sample_rate: int = 24000,
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initialize Dia TTS service.

        Args:
            model_name: Identifier passed to ``Dia.from_pretrained``.
            compute_dtype: Compute dtype name passed to ``Dia.from_pretrained``.
            device: Torch device string (wrapped in ``torch.device``).
            sample_rate: Sample rate handed to the base ``TTSService``.
                NOTE(review): emitted frames are labeled with
                ``self.sample_rate``; confirm this matches the rate the model
                actually produces (upstream Dia examples appear to use
                44.1 kHz -- TODO confirm), otherwise playback speed is wrong.
            params: Generation options. A fresh ``InputParams`` is created
                when omitted so instances never share a default object.
            **kwargs: Forwarded to the base ``TTSService``.
        """
        super().__init__(sample_rate=sample_rate, **kwargs)

        if params is None:
            params = self.InputParams()

        logger.info(f"Initializing Dia TTS service with model: {model_name}")
        torch_device = torch.device(device)
        self._model = Dia.from_pretrained(
            model_name, compute_dtype=compute_dtype, device=torch_device
        )
        # Prefer the pydantic v2 API; fall back to the v1 spelling so the
        # service keeps working on either major version.
        if hasattr(params, "model_dump"):
            self._settings = params.model_dump()
        else:
            self._settings = params.dict()
        logger.info("Dia TTS service initialized")

    def can_generate_metrics(self) -> bool:
        """Report that this service emits metrics."""
        return True

    @staticmethod
    def _to_pcm16_bytes(output: Any) -> bytes:
        """Convert model output to little-endian-native int16 PCM bytes.

        Accepts a mapping with an ``"audio_tensor"`` entry, a torch tensor,
        or anything ``numpy.asarray`` understands. Samples are treated as
        floats in [-1, 1]; values outside that range are clipped so they
        saturate instead of wrapping around when cast to int16.

        Raises:
            ValueError: If the model produced no audio.
        """
        audio = output["audio_tensor"] if isinstance(output, dict) else output
        if audio is None:
            raise ValueError("Dia returned no audio")
        if isinstance(audio, torch.Tensor):
            # detach() so tensors that track gradients can be converted.
            audio = audio.detach().cpu().numpy()
        samples = np.clip(np.asarray(audio, dtype=np.float32), -1.0, 1.0)
        return (samples * 32767).astype(np.int16).tobytes()

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Generate speech for ``text`` and yield it as a single audio frame.

        Yields:
            ``TTSStartedFrame``, one ``TTSAudioRawFrame`` holding the whole
            utterance, then ``TTSStoppedFrame``. If anything fails, an
            ``ErrorFrame`` describing the exception is yielded instead of
            the remaining frames.
        """
        logger.debug(f"Generating TTS for: [{text}]")
        try:
            await self.start_ttfb_metrics()
            yield TTSStartedFrame()

            loop = asyncio.get_running_loop()
            await self.start_tts_usage_metrics(text)

            # run_in_executor only forwards positional arguments, so bind the
            # options by keyword; passing them positionally would assign them
            # to whatever parameters happen to follow ``text``.
            generate = functools.partial(
                self._model.generate,
                text,
                use_torch_compile=self._settings["use_torch_compile"],
                verbose=self._settings["verbose"],
            )
            # Generation is blocking; run it off the event loop.
            output = await loop.run_in_executor(None, generate)

            # Non-streaming: the first byte is available once generation ends.
            await self.stop_ttfb_metrics()

            audio_data = self._to_pcm16_bytes(output)
            yield TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self.sample_rate,
                num_channels=1,
            )
            yield TTSStoppedFrame()
        except Exception as e:
            # loguru ignores stdlib's ``exc_info``; ``exception`` attaches the
            # traceback and, with no extra kwargs, skips str.format on the
            # message (which could itself fail on braces in the error text).
            logger.exception(f"{self} exception: {e}")
            yield ErrorFrame(f"Error generating audio: {str(e)}")