RSHVR commited on
Commit
cb66cb4
·
verified ·
1 Parent(s): d8d34de

Add concurrency

Browse files
Files changed (1) hide show
  1. cohereAPI.py +60 -12
cohereAPI.py CHANGED
@@ -1,10 +1,20 @@
1
  import cohere
2
  import asyncio
3
 
 
 
 
 
 
 
 
 
 
 
4
  def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
5
  """Stream response from Cohere API"""
6
- # Initialize the Cohere client
7
- co = cohere.ClientV2(api_key)
8
 
9
  # Prepare all messages including history
10
  messages = [{"role": "system", "content": system_message}]
@@ -26,15 +36,46 @@ def send_message_stream(system_message, user_message, conversation_history, api_
26
  text_chunk = chunk.delta.message.content.text
27
  full_response += text_chunk
28
  yield text_chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- # Update conversation history after streaming is complete
31
- conversation_history.append({"role": "user", "content": user_message})
32
- conversation_history.append({"role": "assistant", "content": full_response})
33
 
34
  def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
35
  """Non-streaming version for backward compatibility"""
36
- # Initialize the Cohere client
37
- co = cohere.ClientV2(api_key)
38
 
39
  # Prepare all messages including history
40
  messages = [{"role": "system", "content": system_message}]
@@ -50,8 +91,15 @@ def send_message(system_message, user_message, conversation_history, api_key, mo
50
  # Get the response
51
  response_content = response.message.content[0].text
52
 
53
- # Update conversation history for this session
54
- conversation_history.append({"role": "user", "content": user_message})
55
- conversation_history.append({"role": "assistant", "content": response_content})
56
-
57
- return response_content, conversation_history
 
 
 
 
 
 
 
 
1
  import cohere
2
  import asyncio
3
 
4
+ # Global client variable for lazy initialization
5
+ _client = None
6
+
7
+ def get_client(api_key):
8
+ """Get or create Cohere client instance"""
9
+ global _client
10
+ if _client is None:
11
+ _client = cohere.ClientV2(api_key)
12
+ return _client
13
+
14
  def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
15
  """Stream response from Cohere API"""
16
+ # Get or create the Cohere client
17
+ co = get_client(api_key)
18
 
19
  # Prepare all messages including history
20
  messages = [{"role": "system", "content": system_message}]
 
36
  text_chunk = chunk.delta.message.content.text
37
  full_response += text_chunk
38
  yield text_chunk
39
+
40
+ async def send_message_stream_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
41
+ """Async wrapper for streaming response from Cohere API"""
42
+ def _sync_stream():
43
+ return send_message_stream(system_message, user_message, conversation_history, api_key, model_name)
44
+
45
+ # Run the synchronous generator in a thread
46
+ loop = asyncio.get_event_loop()
47
+
48
+ # Use a queue to handle the streaming data
49
+ queue = asyncio.Queue()
50
+
51
+ def _stream_worker():
52
+ try:
53
+ for chunk in _sync_stream():
54
+ loop.call_soon_threadsafe(queue.put_nowait, chunk)
55
+ except Exception as e:
56
+ loop.call_soon_threadsafe(queue.put_nowait, StopIteration(e))
57
+ else:
58
+ loop.call_soon_threadsafe(queue.put_nowait, StopIteration())
59
+
60
+ # Start the worker thread
61
+ import threading
62
+ thread = threading.Thread(target=_stream_worker)
63
+ thread.start()
64
+
65
+ # Yield chunks asynchronously
66
+ while True:
67
+ chunk = await queue.get()
68
+ if isinstance(chunk, StopIteration):
69
+ if chunk.args:
70
+ raise chunk.args[0]
71
+ break
72
+ yield chunk
73
 
 
 
 
74
 
75
  def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
76
  """Non-streaming version for backward compatibility"""
77
+ # Get or create the Cohere client
78
+ co = get_client(api_key)
79
 
80
  # Prepare all messages including history
81
  messages = [{"role": "system", "content": system_message}]
 
91
  # Get the response
92
  response_content = response.message.content[0].text
93
 
94
+ return response_content
95
+
96
+ async def send_message_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
97
+ """Async version using asyncio.to_thread"""
98
+ return await asyncio.to_thread(
99
+ send_message,
100
+ system_message,
101
+ user_message,
102
+ conversation_history,
103
+ api_key,
104
+ model_name
105
+ )