import asyncio
import threading

import cohere

# Global client variable for lazy initialization
_client = None


def get_client(api_key):
    """Get or create the Cohere client instance.

    Note: the client is created once with the first api_key passed in;
    subsequent calls reuse it and ignore a different key.
    """
    global _client
    if _client is None:
        _client = cohere.ClientV2(api_key)
    return _client


def send_message_stream(system_message, user_message, conversation_history, api_key,
                        model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Stream a response from the Cohere API, yielding text chunks."""
    # Get or create the Cohere client
    co = get_client(api_key)

    # Prepare all messages, including the conversation history
    messages = [{"role": "system", "content": system_message}]
    messages.extend(conversation_history)
    messages.append({"role": "user", "content": user_message})

    # Prepare chat parameters
    chat_params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
    }

    # Add max_tokens only if specified
    if max_tokens:
        chat_params["max_tokens"] = int(max_tokens)

    # Send the streaming request to Cohere
    stream = co.chat_stream(**chat_params)

    # Yield text chunks as they arrive
    for chunk in stream:
        if chunk.type == "content-delta":
            yield chunk.delta.message.content.text
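
# Usage sketch (illustrative only): iterate the generator to print tokens as
# they arrive. The key below is a placeholder, not a real credential.
#
#     for token in send_message_stream(
#         "You are a helpful assistant.",
#         "Write a haiku about rivers.",
#         conversation_history=[],
#         api_key="YOUR_COHERE_API_KEY",
#     ):
#         print(token, end="", flush=True)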


async def send_message_stream_async(system_message, user_message, conversation_history, api_key,
                                    model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Async wrapper for streaming a response from the Cohere API."""
    def _sync_stream():
        return send_message_stream(system_message, user_message, conversation_history,
                                   api_key, model_name, temperature, max_tokens)

    # Bridge the synchronous generator into the event loop via a queue
    loop = asyncio.get_running_loop()
    queue = asyncio.Queue()

    def _stream_worker():
        try:
            for chunk in _sync_stream():
                loop.call_soon_threadsafe(queue.put_nowait, chunk)
        except Exception as e:
            # Propagate the exception to the consumer via the sentinel
            loop.call_soon_threadsafe(queue.put_nowait, StopIteration(e))
        else:
            # Signal normal completion
            loop.call_soon_threadsafe(queue.put_nowait, StopIteration())

    # Run the blocking stream in a daemon thread so it cannot hang shutdown
    thread = threading.Thread(target=_stream_worker, daemon=True)
    thread.start()

    # Yield chunks as they land on the queue
    while True:
        chunk = await queue.get()
        if isinstance(chunk, StopIteration):
            # A worker exception rides along in the sentinel's args
            if chunk.args:
                raise chunk.args[0]
            break
        yield chunk
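
# Usage sketch (illustrative only): consume the async generator with
# `async for` inside a running event loop; the key is a placeholder.
#
#     async def demo():
#         async for token in send_message_stream_async(
#             "You are a helpful assistant.",
#             "Hello!",
#             conversation_history=[],
#             api_key="YOUR_COHERE_API_KEY",
#         ):
#             print(token, end="", flush=True)
#
#     asyncio.run(demo())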


def send_message(system_message, user_message, conversation_history, api_key,
                 model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Non-streaming version, kept for backward compatibility."""
    # Get or create the Cohere client
    co = get_client(api_key)

    # Prepare all messages, including the conversation history
    messages = [{"role": "system", "content": system_message}]
    messages.extend(conversation_history)
    messages.append({"role": "user", "content": user_message})

    # Prepare chat parameters
    chat_params = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
    }

    # Add max_tokens only if specified
    if max_tokens:
        chat_params["max_tokens"] = int(max_tokens)

    # Send the request synchronously and return the first content block's text
    response = co.chat(**chat_params)
    return response.message.content[0].text
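
# Usage sketch (illustrative only): a single blocking call that returns the
# full reply as one string; the key is a placeholder.
#
#     reply = send_message(
#         "You are a helpful assistant.",
#         "Summarize streaming vs. non-streaming in one line.",
#         conversation_history=[],
#         api_key="YOUR_COHERE_API_KEY",
#     )
#     print(reply)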


async def send_message_async(system_message, user_message, conversation_history, api_key,
                             model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Async version that runs the blocking call in a worker thread."""
    return await asyncio.to_thread(
        send_message,
        system_message,
        user_message,
        conversation_history,
        api_key,
        model_name,
        temperature,
        max_tokens,
    )