klentyboopathi commited on
Commit
7b20468
·
1 Parent(s): 32b2887
Files changed (1) hide show
  1. bot/bot_websocket_server.py +97 -125
bot/bot_websocket_server.py CHANGED
@@ -66,128 +66,100 @@ load_dotenv(override=True)
66
 
67
 
68
  async def run_bot_websocket_server(websocket_client):
69
- ws_transport = FastAPIWebsocketTransport(
70
- websocket=websocket_client,
71
- params=FastAPIWebsocketParams(
72
- audio_in_enabled=True,
73
- audio_out_enabled=True,
74
- add_wav_header=False,
75
- vad_analyzer=SileroVADAnalyzer(),
76
- serializer=ProtobufFrameSerializer(),
77
- ),
78
- )
79
-
80
- stt = WhisperSTTService(
81
- model="tiny",
82
- device="cpu",
83
- compute_type="default",
84
- language="en",
85
- )
86
-
87
- llm = OLLamaLLMService(
88
- model="smollm:latest",
89
- # params=OLLamaLLMService.InputParams(temperature=0.7, max_tokens=1000),
90
- )
91
-
92
- # TTS = FishAudioTTSService(
93
- # api_key=os.getenv("CARTESIA_API_KEY"),
94
- # voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Reading Lady
95
- # )
96
- # async with aiohttp.ClientSession() as session:
97
- # TTS = XTTSService(
98
- # voice_id="speaker_1",
99
- # language=Language.EN,
100
- # base_url="http://localhost:8000",
101
- # aiohttp_session=session
102
- # )
103
-
104
- context = OpenAILLMContext(
105
- [
106
- {
107
- "role": "system",
108
- "content": SYSTEM_INSTRUCTION,
109
- },
110
- {
111
- "role": "user",
112
- "content": "Start by greeting the user warmly and introducing yourself.",
113
- },
114
- ],
115
- )
116
- context_aggregator = llm.create_context_aggregator(context)
117
-
118
- # RTVI events for Pipecat client UI
119
- rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
120
-
121
- TTS = KokoroTTSService(
122
- model_path=os.path.join(
123
- os.path.dirname(__file__), "assets", "kokoro-v1.0.int8.onnx"
124
- ),
125
- voices_path=os.path.join(os.path.dirname(__file__), "assets", "voices.json"),
126
- voice_id="af",
127
- sample_rate=16000,
128
- )
129
-
130
- # TTS = OrpheusTTSService(
131
- # model_name="canopylabs/orpheus-3b-0.1-ft",
132
- # sample_rate=16000,
133
- # )
134
-
135
- # TTS = ChatterboxTTSService(
136
- # model_name="",
137
- # sample_rate=16000,
138
- # )
139
-
140
- # TTS = DiaTTSService(
141
- # model_name="nari-labs/Dia-1.6B",
142
- # sample_rate=16000,
143
- # )
144
- pipeline = Pipeline(
145
- [
146
- ws_transport.input(),
147
- rtvi,
148
- stt, # STT
149
- context_aggregator.user(),
150
- llm,
151
- TTS, # TTS
152
- ws_transport.output(),
153
- context_aggregator.assistant(),
154
- ]
155
- )
156
-
157
- task = PipelineTask(
158
- pipeline,
159
- params=PipelineParams(
160
- enable_metrics=True,
161
- allow_interruptions=True,
162
- enable_usage_metrics=True,
163
- ),
164
- # enable_turn_tracking=True,
165
- enable_tracing=False,
166
- conversation_id="test",
167
- observers=[RTVIObserver(rtvi)],
168
- )
169
-
170
- @rtvi.event_handler("on_client_ready")
171
- async def on_client_ready(rtvi):
172
- logger.info("Pipecat client ready.")
173
- await rtvi.set_bot_ready()
174
- # Kick off the conversation.
175
- await task.queue_frames([context_aggregator.user().get_context_frame()])
176
-
177
- @ws_transport.event_handler("on_client_connected")
178
- async def on_client_connected(transport, client):
179
- logger.info("Pipecat Client connected")
180
-
181
- @ws_transport.event_handler("on_client_disconnected")
182
- async def on_client_disconnected(transport, client):
183
- logger.info("Pipecat Client disconnected")
184
- await task.cancel()
185
-
186
- @ws_transport.event_handler("on_session_timeout")
187
- async def on_session_timeout(transport, client):
188
- logger.info(f"Entering in timeout for {client.remote_address}")
189
- await task.cancel()
190
-
191
- runner = PipelineRunner()
192
-
193
- await runner.run(task)
 
66
 
67
 
68
  async def run_bot_websocket_server(websocket_client):
69
+ try:
70
+ ws_transport = FastAPIWebsocketTransport(
71
+ websocket=websocket_client,
72
+ params=FastAPIWebsocketParams(
73
+ audio_in_enabled=True,
74
+ audio_out_enabled=True,
75
+ add_wav_header=False,
76
+ vad_analyzer=SileroVADAnalyzer(),
77
+ serializer=ProtobufFrameSerializer(),
78
+ ),
79
+ )
80
+
81
+ stt = WhisperSTTService(
82
+ model="tiny",
83
+ device="cpu",
84
+ compute_type="default",
85
+ language="en",
86
+ )
87
+
88
+ llm = OLLamaLLMService(
89
+ model="smollm:latest",
90
+ )
91
+
92
+ context = OpenAILLMContext(
93
+ [
94
+ {"role": "system", "content": SYSTEM_INSTRUCTION},
95
+ {
96
+ "role": "user",
97
+ "content": "Start by greeting the user warmly and introducing yourself.",
98
+ },
99
+ ]
100
+ )
101
+ context_aggregator = llm.create_context_aggregator(context)
102
+
103
+ rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
104
+
105
+ TTS = KokoroTTSService(
106
+ model_path=os.path.join(
107
+ os.path.dirname(__file__), "assets", "kokoro-v1.0.int8.onnx"
108
+ ),
109
+ voices_path=os.path.join(
110
+ os.path.dirname(__file__), "assets", "voices.json"
111
+ ),
112
+ voice_id="af",
113
+ sample_rate=16000,
114
+ )
115
+
116
+ pipeline = Pipeline(
117
+ [
118
+ ws_transport.input(),
119
+ rtvi,
120
+ stt,
121
+ context_aggregator.user(),
122
+ llm,
123
+ TTS,
124
+ ws_transport.output(),
125
+ context_aggregator.assistant(),
126
+ ]
127
+ )
128
+
129
+ task = PipelineTask(
130
+ pipeline,
131
+ params=PipelineParams(
132
+ enable_metrics=True,
133
+ allow_interruptions=True,
134
+ enable_usage_metrics=True,
135
+ ),
136
+ enable_tracing=False,
137
+ conversation_id="test",
138
+ observers=[RTVIObserver(rtvi)],
139
+ )
140
+
141
+ @rtvi.event_handler("on_client_ready")
142
+ async def on_client_ready(rtvi):
143
+ logger.info("Pipecat client ready.")
144
+ await rtvi.set_bot_ready()
145
+ await task.queue_frames([context_aggregator.user().get_context_frame()])
146
+
147
+ @ws_transport.event_handler("on_client_connected")
148
+ async def on_client_connected(transport, client):
149
+ logger.info("Pipecat Client connected")
150
+
151
+ @ws_transport.event_handler("on_client_disconnected")
152
+ async def on_client_disconnected(transport, client):
153
+ logger.info("Pipecat Client disconnected")
154
+ await task.cancel()
155
+
156
+ @ws_transport.event_handler("on_session_timeout")
157
+ async def on_session_timeout(transport, client):
158
+ logger.info(f"Entering in timeout for {client.remote_address}")
159
+ await task.cancel()
160
+
161
+ runner = PipelineRunner()
162
+ await runner.run(task)
163
+
164
+ except Exception as e:
165
+ logger.exception("Error in run_bot_websocket_server")