RSHVR committed on
Commit
9c2fe2f
·
verified ·
1 Parent(s): 488744e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -55
app.py CHANGED
@@ -1,68 +1,132 @@
1
- import os
2
  import gradio as gr
3
- from fastrtc import Stream, AdditionalOutputs
4
- from fastrtc_walkie_talkie import WalkieTalkie
5
 
6
- # Import your custom models
7
- from tts import tortoise_tts, TortoiseOptions
8
- from stt import whisper_stt
9
  import cohereAPI
10
 
11
- # Environment variables
12
- COHERE_API_KEY = os.getenv("COHERE_API_KEY")
13
- system_message = "You respond concisely, in about 15 words or less"
14
 
15
- # Initialize conversation history
16
  conversation_history = []
17
 
18
- # Create a handler function that uses both your custom models
19
- def response(audio):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  global conversation_history
21
 
22
- # Convert speech to text using your Whisper model
23
- user_message = whisper_stt.stt(audio)
 
 
 
24
 
25
- # Yield the transcription as additional output
26
- yield AdditionalOutputs(user_message)
27
-
28
- # Send text to Cohere API
29
- response_text, updated_history = cohereAPI.send_message(
30
- system_message,
31
- user_message,
32
- conversation_history,
33
- COHERE_API_KEY
34
- )
35
-
36
- # Update conversation history
37
- conversation_history = updated_history
38
-
39
- # Print the response for logging
40
- print(f"Assistant: {response_text}")
41
-
42
- # Use your TTS model to generate audio
43
- tts_options = TortoiseOptions(voice_preset="random")
44
-
45
- # Stream the audio response in chunks
46
- for chunk in tortoise_tts.stream_tts_sync(response_text, tts_options):
47
- yield chunk
48
 
49
- # Create the FastRTC stream with WalkieTalkie for turn detection
50
- stream = Stream(
51
- handler=WalkieTalkie(response), # Use WalkieTalkie instead of ReplyOnPause
52
- modality="audio",
53
- mode="send-receive",
54
- additional_outputs=[gr.Textbox(label="Transcription")],
55
- additional_outputs_handler=lambda old, new: new if old is None else f"{old}\nUser: {new}",
56
- ui_args={
57
- "title": "Voice Assistant (Walkie-Talkie Style)",
58
- "subtitle": "Say 'over' to finish your turn. For example, 'What's the weather like today? over.'"
59
- }
60
- )
61
 
62
- # Launch the Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if __name__ == "__main__":
64
- stream.ui.launch(
65
- server_name="0.0.0.0",
66
- share=False,
67
- show_error=True
68
- )
 
 
1
  import gradio as gr
2
+ import os
 
3
 
 
 
 
4
  import cohereAPI
5
 
 
 
 
6
 
7
# Module-level chat history shared across turns of the conversation.
conversation_history = []

# Cohere chat models offered in the model dropdown.
# The first entry is used as the default selection.
COHERE_MODELS = [
    "command-a-03-2025",
    "command-r7b-12-2024",
    "command-r-plus-08-2024",
    "command-r-08-2024",
    "command-light",
    "command-light-nightly",
    "command",
    "command-nightly",
]
21
+
22
def update_model_choices(provider):
    """Return a refreshed model dropdown for the chosen provider.

    Only "Cohere" is wired up; selecting any other provider yields an
    empty, value-less dropdown.
    """
    if provider != "Cohere":
        return gr.Dropdown(choices=[], value=None)
    return gr.Dropdown(choices=COHERE_MODELS, value=COHERE_MODELS[0])
28
+
29
def show_model_change_info(model_name):
    """Flash an informational toast when the model selection changes.

    Returns *model_name* unchanged so the dropdown keeps its value.
    """
    if not model_name:
        return model_name
    gr.Info(f"picking up from here with {model_name}")
    return model_name
34
+
35
+
36
def respond(message, history, model_name="command-a-03-2025"):
    """Stream a chat reply for *message* via the Cohere API.

    Yields progressively longer strings so Gradio can render the reply as
    it arrives. NOTE(review): *history* (supplied by ChatInterface) is not
    forwarded — the Cohere call uses the module-level conversation_history
    instead, and that global is never updated here; confirm this is the
    intended persistence model.
    """
    global conversation_history

    # Fail fast with a readable message when the key is not configured.
    api_key = os.getenv('COHERE_API_KEY')
    if not api_key:
        yield "Error: COHERE_API_KEY environment variable not set"
        return

    # Instructions prepended to every request.
    system_message = """You are a helpful AI assistant. Provide concise but complete responses.
    Be direct and to the point while ensuring you fully address the user's question or request.
    Do not repeat the user's question in your response. Do not exceed 50 words."""

    try:
        # Accumulate chunks so each yield carries the full reply so far.
        accumulated = ""
        for piece in cohereAPI.send_message_stream(
            system_message=system_message,
            user_message=message,
            conversation_history=conversation_history,
            api_key=api_key,
            model_name=model_name,
        ):
            accumulated += piece
            yield accumulated
    except Exception as e:
        yield f"Error: {str(e)}"
 
 
 
 
65
 
66
# Build the UI: chat surface on the left, provider/model pickers and
# generation settings grouped in an accordion alongside it.
with gr.Blocks() as demo:
    gr.Markdown("## Modular Chatbot")

    with gr.Row():
        with gr.Column(scale=2):
            # Main chat panel; `respond` streams partial replies.
            chat_interface = gr.ChatInterface(
                fn=respond,
                type="messages",
                save_history=True,
            )
        with gr.Accordion("Chat Settings", elem_id="chat_settings_group"):
            with gr.Row():
                with gr.Column(scale=3):
                    provider = gr.Dropdown(
                        choices=["Cohere", "OpenAI", "Anthropic", "Google", "HuggingFace"],
                        value="Cohere",
                        info="Provider",
                        elem_id="provider_dropdown",
                        interactive=True,
                        show_label=False,
                    )
                    model = gr.Dropdown(
                        choices=COHERE_MODELS,
                        value=COHERE_MODELS[0],
                        info="Model",
                        elem_id="model_dropdown",
                        interactive=True,
                        show_label=False,
                    )

                    # Switching provider repopulates the model list.
                    provider.change(
                        fn=update_model_choices,
                        inputs=[provider],
                        outputs=[model],
                    )
                    # Changing model shows an informational toast.
                    model.change(
                        fn=show_model_change_info,
                        inputs=[model],
                        outputs=[model],
                    )

                with gr.Column(scale=1):
                    # NOTE(review): neither setting below is passed to
                    # `respond`, so they currently have no effect on
                    # generation — wire them up (e.g. via ChatInterface
                    # additional_inputs) if that is intended.
                    temperature = gr.Slider(
                        label="Temperature",
                        info="Higher values make output more creative",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.01,
                        elem_id="temperature_slider",
                        interactive=True,
                    )
                    max_tokens = gr.Textbox(
                        label="Max Tokens",
                        info="Higher values allow longer responses.",
                        value="8192",
                        elem_id="max_tokens_input",
                        interactive=True,
                        show_label=False,
                    )

if __name__ == "__main__":
    demo.launch()