import gradio as gr
import os
import cohereAPI
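
# NOTE: `cohereAPI` appears to be a local helper module bundled with this Space
# (it is not the official `cohere` SDK). Based on how it is used below, it is
# expected to expose `send_message_stream_async`, an async generator that yields
# response text chunks from the Cohere chat API. A sketch of such a module is
# outlined at the end of this file.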

# Model configurations
COHERE_MODELS = [
    "command-a-03-2025",
    "command-r7b-12-2024",
    "command-r-plus-08-2024",
    "command-r-08-2024",
    "command-light",
    "command-light-nightly",
    "command",
    "command-nightly"
]

def update_model_choices(provider):
    """Update model dropdown choices based on the selected provider"""
    if provider == "Cohere":
        return gr.Dropdown(choices=COHERE_MODELS, value=COHERE_MODELS[0])
    else:
        return gr.Dropdown(choices=[], value=None)

def show_model_change_info(model_name):
    """Show an info toast when the model is changed"""
    if model_name:
        gr.Info(f"Picking up the conversation with {model_name}")
    return model_name

async def respond(message, history, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Generate streaming response using the Cohere API"""
    # Convert Gradio history format to API format
    conversation_history = []
    if history:
        for entry in history:
            if isinstance(entry, dict):
                # Clean dict format - only keep role and content
                if "role" in entry and "content" in entry:
                    conversation_history.append({
                        "role": entry["role"],
                        "content": entry["content"]
                    })
            elif isinstance(entry, (list, tuple)) and len(entry) == 2:
                # Old format: [user_msg, assistant_msg]
                user_msg, assistant_msg = entry
                if user_msg:
                    conversation_history.append({"role": "user", "content": str(user_msg)})
                if assistant_msg:
                    conversation_history.append({"role": "assistant", "content": str(assistant_msg)})
            else:
                # Handle other formats gracefully
                continue

    # Get API key from environment
    api_key = os.getenv('COHERE_API_KEY')
    if not api_key:
        yield "Error: COHERE_API_KEY environment variable not set"
        return

    # System message for the chatbot
    system_message = """You are a helpful AI assistant. Provide concise but complete responses.
Be direct and to the point while ensuring you fully address the user's question or request.
Do not repeat the user's question in your response. Do not exceed 50 words."""

    try:
        # Use async streaming function
        partial_message = ""
        async for chunk in cohereAPI.send_message_stream_async(
            system_message=system_message,
            user_message=message,
            conversation_history=conversation_history,
            api_key=api_key,
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens
        ):
            partial_message += chunk
            yield partial_message
    except Exception as e:
        yield f"Error: {str(e)}"

with gr.Blocks() as demo:
    gr.Markdown("""## Modular TTS-Chatbot

Status: In Development

The goal of this project is to enable voice chat with any supported LLM that does not yet have built-in speech, similar to the experience offered by Gemini or GPT-4o.
""")

    # State components to track current values
    temperature_state = gr.State(value=0.7)
    max_tokens_state = gr.State(value=8192)  # matches the Max Tokens textbox default below
    model_state = gr.State(value=COHERE_MODELS[0])

    with gr.Row():
        with gr.Column(scale=2):
            # Define wrapper function after all components are created
            async def chat_wrapper(message, history, model_val, temp_val, tokens_val):
                # Use the state values directly
                current_model = model_val if model_val else COHERE_MODELS[0]
                current_temp = temp_val if temp_val is not None else 0.7
                current_max_tokens = tokens_val

                # Stream the response
                async for chunk in respond(message, history, current_model, current_temp, current_max_tokens):
                    yield chunk

            # Create chat interface using the wrapper with additional inputs
            chat_interface = gr.ChatInterface(
                fn=chat_wrapper,
                type="messages",
                save_history=True,
                additional_inputs=[model_state, temperature_state, max_tokens_state]
            )

    with gr.Accordion("Chat Settings", elem_id="chat_settings_group"):
        with gr.Row():
            with gr.Column(scale=3):
                provider = gr.Dropdown(
                    info="Provider",
                    choices=["Cohere", "OpenAI", "Anthropic", "Google", "HuggingFace"],
                    value="Cohere",
                    elem_id="provider_dropdown",
                    interactive=True,
                    show_label=False
                )
                model = gr.Dropdown(
                    info="Model",
                    choices=COHERE_MODELS,
                    value=COHERE_MODELS[0],
                    elem_id="model_dropdown",
                    interactive=True,
                    show_label=False
                )

                # Set up event handler for provider change
                provider.change(
                    fn=update_model_choices,
                    inputs=[provider],
                    outputs=[model]
                )

                # Show the info toast and update the model state when the model changes.
                # Writing to model_state instead of back to the model dropdown itself
                # avoids re-triggering this change event.
                model.change(
                    fn=show_model_change_info,
                    inputs=[model],
                    outputs=[model_state]
                )

            with gr.Column(scale=1):
                temperature = gr.Slider(
                    label="Temperature",
                    info="Higher values make output more creative",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.01,
                    elem_id="temperature_slider",
                    interactive=True,
                )
                max_tokens = gr.Textbox(
                    label="Max Tokens",
                    info="Higher values allow longer responses. Leave empty for default.",
                    value="8192",
                    elem_id="max_tokens_input",
                    interactive=True,
                    show_label=True,
                )

                # Update state when temperature changes
                temperature.change(
                    fn=lambda x: x,
                    inputs=[temperature],
                    outputs=[temperature_state]
                )

                # Update state when max_tokens changes; non-numeric input falls back to None
                max_tokens.change(
                    fn=lambda x: int(str(x).strip()) if str(x).strip().isdigit() else None,
                    inputs=[max_tokens],
                    outputs=[max_tokens_state]
                )

if __name__ == "__main__":
    demo.launch()
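
# ---------------------------------------------------------------------------
# Appendix: the `cohereAPI` helper module is not shown in this file. The sketch
# below is one possible implementation, assuming the official Cohere v2 Python
# SDK (`cohere.AsyncClientV2` and its `chat_stream` method); it is illustrative
# only and not necessarily how the actual helper is written.
#
#   import cohere
#
#   async def send_message_stream_async(system_message, user_message,
#                                        conversation_history, api_key,
#                                        model_name, temperature=0.7,
#                                        max_tokens=None):
#       """Yield response text chunks from the Cohere v2 chat streaming API."""
#       client = cohere.AsyncClientV2(api_key=api_key)
#       # Build the message list: system prompt, prior turns, then the new user turn
#       messages = [{"role": "system", "content": system_message}]
#       messages.extend(conversation_history)
#       messages.append({"role": "user", "content": user_message})
#       kwargs = {"model": model_name, "messages": messages, "temperature": temperature}
#       if max_tokens is not None:
#           kwargs["max_tokens"] = max_tokens
#       async for event in client.chat_stream(**kwargs):
#           # "content-delta" events carry the incremental generated text
#           if event.type == "content-delta":
#               yield event.delta.message.content.text
# ---------------------------------------------------------------------------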