from __future__ import annotations

import asyncio

from .._run_impl import TraceCtxManager
from ..exceptions import UserError
from ..logger import logger
from .input import AudioInput, StreamedAudioInput
from .model import STTModel, TTSModel
from .pipeline_config import VoicePipelineConfig
from .result import StreamedAudioResult
from .workflow import VoiceWorkflowBase


class VoicePipeline:
    """An opinionated voice agent pipeline. It works in three steps:

    1. Transcribe audio input into text.
    2. Run the provided `workflow`, which produces a sequence of text responses.
    3. Convert the text responses into streaming audio output.
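
    Example (a minimal sketch, not a complete program; `MyWorkflow` and the
    audio buffer are illustrative assumptions):

        pipeline = VoicePipeline(workflow=MyWorkflow())
        result = await pipeline.run(AudioInput(buffer=audio_buffer))
        async for event in result.stream():
            ...  # play the event's audio on your output device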
"""

    def __init__(
        self,
        *,
        workflow: VoiceWorkflowBase,
        stt_model: STTModel | str | None = None,
        tts_model: TTSModel | str | None = None,
        config: VoicePipelineConfig | None = None,
    ):
        """Create a new voice pipeline.

        Args:
            workflow: The workflow to run. See `VoiceWorkflowBase`.
            stt_model: The speech-to-text model to use. If not provided, a default OpenAI
                model will be used.
            tts_model: The text-to-speech model to use. If not provided, a default OpenAI
                model will be used.
            config: The pipeline configuration. If not provided, a default configuration will be
                used.
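
        Example (sketch; passing a model name rather than an instance, where
        the name shown is an assumption about available models):

            VoicePipeline(workflow=my_workflow, stt_model="gpt-4o-transcribe")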
"""
self.workflow = workflow
self.stt_model = stt_model if isinstance(stt_model, STTModel) else None
self.tts_model = tts_model if isinstance(tts_model, TTSModel) else None
self._stt_model_name = stt_model if isinstance(stt_model, str) else None
self._tts_model_name = tts_model if isinstance(tts_model, str) else None
self.config = config or VoicePipelineConfig()

    async def run(self, audio_input: AudioInput | StreamedAudioInput) -> StreamedAudioResult:
        """Run the voice pipeline.

        Args:
            audio_input: The audio input to process. This can either be an `AudioInput` instance,
                which is a single static buffer, or a `StreamedAudioInput` instance, which is a
                stream of audio data that you can append to.

        Returns:
            A `StreamedAudioResult` instance. You can use this object to stream audio events and
            play them out.
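
        Example (sketch of multi-turn use; the microphone loop producing
        `chunk` is assumed):

            streamed_input = StreamedAudioInput()
            result = await pipeline.run(streamed_input)
            # Push audio chunks as they arrive from the microphone:
            await streamed_input.add_audio(chunk)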
"""
        if isinstance(audio_input, AudioInput):
            return await self._run_single_turn(audio_input)
        elif isinstance(audio_input, StreamedAudioInput):
            return await self._run_multi_turn(audio_input)
        else:
            raise UserError(f"Unsupported audio input type: {type(audio_input)}")

    def _get_tts_model(self) -> TTSModel:
        # Resolve the TTS model lazily: ask the configured provider (passing the
        # model name, if any) the first time it is needed.
        if not self.tts_model:
            self.tts_model = self.config.model_provider.get_tts_model(self._tts_model_name)
        return self.tts_model

    def _get_stt_model(self) -> STTModel:
        # Same lazy resolution for the speech-to-text model.
        if not self.stt_model:
            self.stt_model = self.config.model_provider.get_stt_model(self._stt_model_name)
        return self.stt_model

    async def _process_audio_input(self, audio_input: AudioInput) -> str:
        # Transcribe a single static audio buffer into text.
        model = self._get_stt_model()
        return await model.transcribe(
            audio_input,
            self.config.stt_settings,
            self.config.trace_include_sensitive_data,
            self.config.trace_include_sensitive_audio_data,
        )

    async def _run_single_turn(self, audio_input: AudioInput) -> StreamedAudioResult:
        # Since this is single turn, we can use the TraceCtxManager to manage
        # starting/ending the trace.
        with TraceCtxManager(
            workflow_name=self.config.workflow_name or "Voice Agent",
            trace_id=None,  # Automatically generated
            group_id=self.config.group_id,
            metadata=self.config.trace_metadata,
            disabled=self.config.tracing_disabled,
        ):
            input_text = await self._process_audio_input(audio_input)
            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )

            async def stream_events():
                try:
                    # Feed each text response from the workflow into the TTS stream,
                    # then mark the turn (and the whole result) as complete.
                    async for text_event in self.workflow.run(input_text):
                        await output._add_text(text_event)
                    await output._turn_done()
                    await output._done()
                except Exception as e:
                    logger.error(f"Error processing single turn: {e}")
                    await output._add_error(e)
                    raise e

            output._set_task(asyncio.create_task(stream_events()))
            return output

    async def _run_multi_turn(self, audio_input: StreamedAudioInput) -> StreamedAudioResult:
        # Multi-turn: keep one trace open for the lifetime of the streamed input and
        # run the workflow once per transcribed turn.
        with TraceCtxManager(
            workflow_name=self.config.workflow_name or "Voice Agent",
            trace_id=None,
            group_id=self.config.group_id,
            metadata=self.config.trace_metadata,
            disabled=self.config.tracing_disabled,
        ):
            output = StreamedAudioResult(
                self._get_tts_model(), self.config.tts_settings, self.config
            )
            transcription_session = await self._get_stt_model().create_session(
                audio_input,
                self.config.stt_settings,
                self.config.trace_include_sensitive_data,
                self.config.trace_include_sensitive_audio_data,
            )

            async def process_turns():
                try:
                    async for input_text in transcription_session.transcribe_turns():
                        result = self.workflow.run(input_text)
                        async for text_event in result:
                            await output._add_text(text_event)
                        await output._turn_done()
                except Exception as e:
                    logger.error(f"Error processing turns: {e}")
                    await output._add_error(e)
                    raise e
                finally:
                    await transcription_session.close()
                    await output._done()

            output._set_task(asyncio.create_task(process_turns()))
            return output