from __future__ import annotations

import asyncio

from .._run_impl import TraceCtxManager
from ..exceptions import UserError
from ..logger import logger
from .input import AudioInput, StreamedAudioInput
from .model import STTModel, TTSModel
from .pipeline_config import VoicePipelineConfig
from .result import StreamedAudioResult
from .workflow import VoiceWorkflowBase


class VoicePipeline:
    """An opinionated voice agent pipeline. It works in three steps:

    1. Transcribe audio input into text.
    2. Run the provided `workflow`, which produces a sequence of text responses.
    3. Convert the text responses into streaming audio output.
    """

    def __init__(
        self,
        *,
        workflow: VoiceWorkflowBase,
        stt_model: STTModel | str | None = None,
        tts_model: TTSModel | str | None = None,
        config: VoicePipelineConfig | None = None,
    ):
        """Create a new voice pipeline.

        Args:
            workflow: The workflow to run. See `VoiceWorkflowBase`.
            stt_model: The speech-to-text model to use. If not provided, a default OpenAI
                model will be used.
            tts_model: The text-to-speech model to use. If not provided, a default OpenAI
                model will be used.
            config: The pipeline configuration. If not provided, a default configuration
                will be used.
        """
        self.workflow = workflow
        # A model may be given as an instance (used directly) or as a string name (resolved
        # lazily through `config.model_provider` on first use).
        self.stt_model: STTModel | None = stt_model if isinstance(stt_model, STTModel) else None
        self.tts_model: TTSModel | None = tts_model if isinstance(tts_model, TTSModel) else None
        self._stt_model_name: str | None = stt_model if isinstance(stt_model, str) else None
        self._tts_model_name: str | None = tts_model if isinstance(tts_model, str) else None
        self.config = config or VoicePipelineConfig()

    async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
        """Run the voice pipeline.

        Args:
            audio_input: The audio input to process. This can either be an `AudioInput`
                instance, which is a single static buffer, or a `StreamedAudioInput`
                instance, which is a stream of audio data that you can append to.

        Returns:
            A `StreamedAudioResult` instance. You can use this object to stream audio
            events and play them out.

        Raises:
            UserError: If `audio_input` is neither an `AudioInput` nor a
                `StreamedAudioInput`.
        """
        if isinstance(audio_input, AudioInput):
            return await self._run_single_turn(audio_input)
        if isinstance(audio_input, StreamedAudioInput):
            return await self._run_multi_turn(audio_input)
        raise UserError(f"Unsupported audio input type: {type(audio_input)}")

    def _get_tts_model(self) -> TTSModel:
        """Return the text-to-speech model, resolving it from the provider on first use."""
        # Compare against None explicitly so that a user-supplied model which happens to be
        # falsy is never silently replaced by the provider default.
        if self.tts_model is None:
            self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
        return self.tts_model

    def _get_stt_model(self) -> STTModel:
        """Return the speech-to-text model, resolving it from the provider on first use."""
        # See `_get_tts_model` for why this is an identity check rather than a truthiness one.
        if self.stt_model is None:
            self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
        return self.stt_model

    def _trace_context(self) -> TraceCtxManager:
        """Build the `TraceCtxManager` shared by the single-turn and multi-turn paths."""
        return TraceCtxManager(
            workflow_name=self.config.workflow_name or "Voice Agent",
            trace_id=None,  # Automatically generated
            group_id=self.config.group_id,
            metadata=self.config.trace_metadata,
            disabled=self.config.tracing_disabled,
        )

    async def _process_audio_input(self, audio_input: AudioInput) -> str:
        """Transcribe a static audio buffer with the configured speech-to-text model.

        Args:
            audio_input: The audio buffer to transcribe.

        Returns:
            The value returned by the model's `transcribe` call.
        """
        model = self._get_stt_model()
        return await model.transcribe(
            audio_input,
            self.config.stt_settings,
            self.config.trace_include_sensitive_data,
            self.config.trace_include_sensitive_audio_data,
        )

    async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
        """Transcribe one audio buffer, run the workflow once, and stream the result.

        The workflow is driven by a background task attached to the returned result; this
        method returns as soon as that task has been scheduled.

        Args:
            audio_input: The static audio buffer to process.

        Returns:
            The `StreamedAudioResult` that the background task feeds.
        """
        # Since this is single turn, we can use the TraceCtxManager to manage starting/ending
        # the trace
        with self._trace_context():
            input_text = await self._process_audio_input(audio_input)

            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )

            async def stream_events():
                try:
                    async for text_event in self.workflow.run(input_text):
                        await output._add_text(text_event)
                    await output._turn_done()
                    await output._done()
                except Exception as e:
                    logger.error(f"Error processing single turn: {e}")
                    # Record the failure on the output before re-raising it from the task.
                    await output._add_error(e)
                    raise

            output._set_task(asyncio.create_task(stream_events()))
            return output

    async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
        """Open a transcription session and run the workflow once per transcribed turn.

        The turns are processed by a background task attached to the returned result; this
        method returns as soon as that task has been scheduled.

        Args:
            audio_input: The streamed audio input to transcribe.

        Returns:
            The `StreamedAudioResult` that the background task feeds.
        """
        with self._trace_context():
            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )

            transcription_session = await self._get_stt_model().create_session(
                audio_input,
                self.config.stt_settings,
                self.config.trace_include_sensitive_data,
                self.config.trace_include_sensitive_audio_data,
            )

            async def process_turns():
                try:
                    async for input_text in transcription_session.transcribe_turns():
                        result = self.workflow.run(input_text)
                        async for text_event in result:
                            await output._add_text(text_event)
                        await output._turn_done()
                except Exception as e:
                    logger.error(f"Error processing turns: {e}")
                    # Record the failure on the output before re-raising it from the task.
                    await output._add_error(e)
                    raise
                finally:
                    # Nested try/finally: `output._done()` must still run even if closing
                    # the transcription session itself raises.
                    try:
                        await transcription_session.close()
                    finally:
                        await output._done()

            output._set_task(asyncio.create_task(process_turns()))
            return output