RSHVR committed on
Commit
0707b48
·
verified ·
1 Parent(s): dad23dc

Update cohereAPI.py

Browse files
Files changed (1) hide show
  1. cohereAPI.py +32 -14
cohereAPI.py CHANGED
@@ -11,7 +11,7 @@ def get_client(api_key):
11
  _client = cohere.ClientV2(api_key)
12
  return _client
13
 
14
- def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
15
  """Stream response from Cohere API"""
16
  # Get or create the Cohere client
17
  co = get_client(api_key)
@@ -21,11 +21,19 @@ def send_message_stream(system_message, user_message, conversation_history, api_
21
  messages.extend(conversation_history)
22
  messages.append({"role": "user", "content": user_message})
23
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Send streaming request to Cohere
25
- stream = co.chat_stream(
26
- model=model_name,
27
- messages=messages
28
- )
29
 
30
  # Collect full response for history
31
  full_response = ""
@@ -37,10 +45,10 @@ def send_message_stream(system_message, user_message, conversation_history, api_
37
  full_response += text_chunk
38
  yield text_chunk
39
 
40
- async def send_message_stream_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
41
  """Async wrapper for streaming response from Cohere API"""
42
  def _sync_stream():
43
- return send_message_stream(system_message, user_message, conversation_history, api_key, model_name)
44
 
45
  # Run the synchronous generator in a thread
46
  loop = asyncio.get_event_loop()
@@ -72,7 +80,7 @@ async def send_message_stream_async(system_message, user_message, conversation_h
72
  yield chunk
73
 
74
 
75
- def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
76
  """Non-streaming version for backward compatibility"""
77
  # Get or create the Cohere client
78
  co = get_client(api_key)
@@ -82,18 +90,26 @@ def send_message(system_message, user_message, conversation_history, api_key, mo
82
  messages.extend(conversation_history)
83
  messages.append({"role": "user", "content": user_message})
84
 
 
 
 
 
 
 
 
 
 
 
 
85
  # Send request to Cohere synchronously
86
- response = co.chat(
87
- model=model_name,
88
- messages=messages
89
- )
90
 
91
  # Get the response
92
  response_content = response.message.content[0].text
93
 
94
  return response_content
95
 
96
- async def send_message_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025"):
97
  """Async version using asyncio.to_thread"""
98
  return await asyncio.to_thread(
99
  send_message,
@@ -101,5 +117,7 @@ async def send_message_async(system_message, user_message, conversation_history,
101
  user_message,
102
  conversation_history,
103
  api_key,
104
- model_name
 
 
105
  )
 
11
  _client = cohere.ClientV2(api_key)
12
  return _client
13
 
14
+ def send_message_stream(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
15
  """Stream response from Cohere API"""
16
  # Get or create the Cohere client
17
  co = get_client(api_key)
 
21
  messages.extend(conversation_history)
22
  messages.append({"role": "user", "content": user_message})
23
 
24
+ # Prepare chat parameters
25
+ chat_params = {
26
+ "model": model_name,
27
+ "messages": messages,
28
+ "temperature": temperature
29
+ }
30
+
31
+ # Add max_tokens if specified
32
+ if max_tokens:
33
+ chat_params["max_tokens"] = int(max_tokens)
34
+
35
  # Send streaming request to Cohere
36
+ stream = co.chat_stream(**chat_params)
 
 
 
37
 
38
  # Collect full response for history
39
  full_response = ""
 
45
  full_response += text_chunk
46
  yield text_chunk
47
 
48
+ async def send_message_stream_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
49
  """Async wrapper for streaming response from Cohere API"""
50
  def _sync_stream():
51
+ return send_message_stream(system_message, user_message, conversation_history, api_key, model_name, temperature, max_tokens)
52
 
53
  # Run the synchronous generator in a thread
54
  loop = asyncio.get_event_loop()
 
80
  yield chunk
81
 
82
 
83
+ def send_message(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
84
  """Non-streaming version for backward compatibility"""
85
  # Get or create the Cohere client
86
  co = get_client(api_key)
 
90
  messages.extend(conversation_history)
91
  messages.append({"role": "user", "content": user_message})
92
 
93
+ # Prepare chat parameters
94
+ chat_params = {
95
+ "model": model_name,
96
+ "messages": messages,
97
+ "temperature": temperature
98
+ }
99
+
100
+ # Add max_tokens if specified
101
+ if max_tokens:
102
+ chat_params["max_tokens"] = int(max_tokens)
103
+
104
  # Send request to Cohere synchronously
105
+ response = co.chat(**chat_params)
 
 
 
106
 
107
  # Get the response
108
  response_content = response.message.content[0].text
109
 
110
  return response_content
111
 
112
async def send_message_async(system_message, user_message, conversation_history, api_key, model_name="command-a-03-2025", temperature=0.7, max_tokens=None):
    """Async version using asyncio.to_thread.

    Delegates to the synchronous ``send_message`` on a worker thread so the
    event loop is not blocked by the network call.

    Args:
        system_message: System prompt placed at the start of the chat.
        user_message: The new user turn to send.
        conversation_history: Prior messages (list of role/content dicts).
        api_key: Cohere API key used to build/reuse the client.
        model_name: Cohere model identifier.
        temperature: Sampling temperature forwarded to the chat call.
        max_tokens: Optional response-length cap; forwarded when set.

    Returns:
        The assistant's response text from ``send_message``.
    """
    # Forward every argument positionally, matching send_message's signature.
    return await asyncio.to_thread(
        send_message,
        system_message,
        user_message,
        conversation_history,
        api_key,
        model_name,
        temperature,
        max_tokens,
    )