mriusero commited on
Commit
ab7f293
·
1 Parent(s): a9effa1

feat: thinking + streaming

Browse files
Files changed (2) hide show
  1. src/agent/inference.py +157 -124
  2. src/ui/sidebar.py +77 -3
src/agent/inference.py CHANGED
@@ -32,131 +32,164 @@ class MistralAgent:
32
  ]
33
  ).get('tools')
34
 
35
- def make_initial_request(self, input):
36
- """Make the initial request to the agent with the given input."""
37
- with open("./prompt.md", 'r', encoding='utf-8') as file:
38
- self.prompt = file.read()
39
- messages = [
40
- {"role": "system", "content": self.prompt},
41
- {"role": "user", "content": input},
42
- {
43
- "role": "assistant",
44
- "content": "THINKING:\nLet's tackle this problem, ",
45
- "prefix": True,
46
- },
47
- ]
48
- payload = {
49
- "agent_id": self.agent_id,
50
- "messages": messages,
51
- "max_tokens": None,
52
- "stream": False,
53
- "stop": None,
54
- "random_seed": None,
55
- "response_format": None,
56
- "tools": self.tools,
57
- "tool_choice": 'auto',
58
- "presence_penalty": 0,
59
- "frequency_penalty": 0,
60
- "n": 1,
61
- "prediction": None,
62
- "parallel_tool_calls": None
63
- }
64
- return self.client.agents.complete(**payload), messages
65
-
66
- def run(self, input):
67
- """Run the agent with the given input and process the response."""
68
- print("\n===== Asking the agent =====\n")
69
- response, messages = self.make_initial_request(input)
70
- first_iteration = True
71
-
72
- while True:
73
- time.sleep(1)
74
- if hasattr(response, 'choices') and response.choices:
75
- choice = response.choices[0]
76
-
77
- if first_iteration:
78
- messages = [message for message in messages if not message.get("prefix")]
79
- messages.append(
80
- {
81
- "role": "assistant",
82
- "content": choice.message.content,
83
- "prefix": True,
84
- },
85
- )
86
- first_iteration = False
87
- else:
88
- if choice.message.tool_calls:
89
- results = []
90
-
91
- for tool_call in choice.message.tool_calls:
92
- function_name = tool_call.function.name
93
- function_params = json.loads(tool_call.function.arguments)
94
-
95
- try:
96
- function_result = self.names_to_functions[function_name](**function_params)
97
- results.append((tool_call.id, function_name, function_result))
98
-
99
- except Exception as e:
100
- results.append((tool_call.id, function_name, None))
101
 
102
- for tool_call_id, function_name, function_result in results:
103
- messages.append({
104
- "role": "assistant",
105
- "tool_calls": [
106
- {
107
- "id": tool_call_id,
108
- "type": "function",
109
- "function": {
110
- "name": function_name,
111
- "arguments": json.dumps(function_params),
112
- }
113
- }
114
- ]
115
- })
116
- messages.append(
117
- {
118
- "role": "tool",
119
- "content": function_result if function_result is not None else f"Error occurred: {function_name} failed to execute",
120
- "tool_call_id": tool_call_id,
121
- },
122
- )
123
- for message in messages:
124
- if "prefix" in message:
125
- del message["prefix"]
126
- messages.append(
127
- {
128
- "role": "assistant",
129
- "content": f"Based on the results, ",
130
- "prefix": True,
131
- }
132
- )
133
- else:
134
- for message in messages:
135
- if "prefix" in message:
136
- del message["prefix"]
137
- messages.append(
138
- {
139
- "role": "assistant",
140
- "content": choice.message.content,
141
- }
142
- )
143
- if 'FINAL ANSWER:' in choice.message.content:
144
- print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
145
- ans = choice.message.content.split('FINAL ANSWER:')[1].strip()
146
 
147
- timestamp = time.strftime("%Y%m%d-%H%M%S")
148
- output_file = f"chat_{timestamp}.json"
149
- with open(output_file, "w", encoding="utf-8") as f:
150
- json.dump(messages, f, indent=2, ensure_ascii=False)
151
- print(f"Conversation enregistrée dans {output_file}")
152
 
153
- return ans
154
 
155
- print("\n===== MESSAGES BEFORE API CALL =====\n", json.dumps(messages, indent=2))
156
- time.sleep(1)
157
- response = self.client.agents.complete(
158
- agent_id=self.agent_id,
159
- messages=messages,
160
- tools=self.tools,
161
- tool_choice='auto',
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ]
33
  ).get('tools')
34
 
35
+ #def make_initial_request(self, input):
36
+ # """Make the initial request to the agent with the given input."""
37
+ # with open("./prompt.md", 'r', encoding='utf-8') as file:
38
+ # self.prompt = file.read()
39
+ # messages = [
40
+ # {"role": "system", "content": self.prompt},
41
+ # {"role": "user", "content": input},
42
+ # {
43
+ # "role": "assistant",
44
+ # "content": "THINKING:\nLet's tackle this problem, ",
45
+ # "prefix": True,
46
+ # },
47
+ # ]
48
+ # payload = {
49
+ # "agent_id": self.agent_id,
50
+ # "messages": messages,
51
+ # "max_tokens": None,
52
+ # "stream": True,
53
+ # "stop": None,
54
+ # "random_seed": None,
55
+ # "response_format": None,
56
+ # "tools": self.tools,
57
+ # "tool_choice": 'auto',
58
+ # "presence_penalty": 0,
59
+ # "frequency_penalty": 0,
60
+ # "n": 1,
61
+ # "prediction": None,
62
+ # "parallel_tool_calls": None
63
+ # }
64
+ # stream = self.client.agents.complete(**payload)
65
+ # return stream, messages
66
+ #
67
+ #def run(self, input):
68
+ # """Run the agent with the given input and process the response."""
69
+ # print("\n===== Asking the agent =====\n")
70
+ # stream, messages = self.make_initial_request(input)
71
+ #
72
+ # for data in stream:
73
+ # # Si `stream` renvoie des chaînes brutes de type `data: {...}`
74
+ # if isinstance(data, str) and data.startswith("data: "):
75
+ # try:
76
+ # json_str = data[len("data: "):].strip()
77
+ # if json_str == "[DONE]":
78
+ # break
79
+ # chunk = json.loads(json_str)
80
+ # delta = chunk.get("choices", [{}])[0].get("delta", {})
81
+ # content = delta.get("content")
82
+ # if content:
83
+ # yield content
84
+ #
85
+ # # Fin de réponse
86
+ # if chunk["choices"][0].get("finish_reason") is not None:
87
+ # break
88
+ # except json.JSONDecodeError:
89
+ # continue
90
+ #
91
+ # # Si `stream` donne directement des dicts (selon ton client)
92
+ # elif isinstance(data, dict):
93
+ # delta = data.get("choices", [{}])[0].get("delta", {})
94
+ # content = delta.get("content")
95
+ # if content:
96
+ # yield content
97
+ #
98
+ # if data["choices"][0].get("finish_reason") is not None:
99
+ # break
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
 
 
 
 
 
102
 
103
+ #first_iteration = True
104
 
105
+ #while True:
106
+ # time.sleep(1)
107
+ # if hasattr(response, 'choices') and response.choices:
108
+ # choice = response.choices[0]
109
+ #
110
+ # if first_iteration:
111
+ # messages = [message for message in messages if not message.get("prefix")]
112
+ # messages.append(
113
+ # {
114
+ # "role": "assistant",
115
+ # "content": choice.message.content,
116
+ # "prefix": True,
117
+ # },
118
+ # )
119
+ # first_iteration = False
120
+ # else:
121
+ # if choice.message.tool_calls:
122
+ # results = []
123
+ #
124
+ # for tool_call in choice.message.tool_calls:
125
+ # function_name = tool_call.function.name
126
+ # function_params = json.loads(tool_call.function.arguments)
127
+ #
128
+ # try:
129
+ # function_result = self.names_to_functions[function_name](**function_params)
130
+ # results.append((tool_call.id, function_name, function_result))
131
+ #
132
+ # except Exception as e:
133
+ # results.append((tool_call.id, function_name, None))
134
+ #
135
+ # for tool_call_id, function_name, function_result in results:
136
+ # messages.append({
137
+ # "role": "assistant",
138
+ # "tool_calls": [
139
+ # {
140
+ # "id": tool_call_id,
141
+ # "type": "function",
142
+ # "function": {
143
+ # "name": function_name,
144
+ # "arguments": json.dumps(function_params),
145
+ # }
146
+ # }
147
+ # ]
148
+ # })
149
+ # messages.append(
150
+ # {
151
+ # "role": "tool",
152
+ # "content": function_result if function_result is not None else f"Error occurred: {function_name} failed to execute",
153
+ # "tool_call_id": tool_call_id,
154
+ # },
155
+ # )
156
+ # for message in messages:
157
+ # if "prefix" in message:
158
+ # del message["prefix"]
159
+ # messages.append(
160
+ # {
161
+ # "role": "assistant",
162
+ # "content": f"Based on the results, ",
163
+ # "prefix": True,
164
+ # }
165
+ # )
166
+ # else:
167
+ # for message in messages:
168
+ # if "prefix" in message:
169
+ # del message["prefix"]
170
+ # messages.append(
171
+ # {
172
+ # "role": "assistant",
173
+ # "content": choice.message.content,
174
+ # }
175
+ # )
176
+ # if 'FINAL ANSWER:' in choice.message.content:
177
+ # print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
178
+ # ans = choice.message.content.split('FINAL ANSWER:')[1].strip()
179
+ #
180
+ # timestamp = time.strftime("%Y%m%d-%H%M%S")
181
+ # output_file = f"chat_{timestamp}.json"
182
+ # with open(output_file, "w", encoding="utf-8") as f:
183
+ # json.dump(messages, f, indent=2, ensure_ascii=False)
184
+ # print(f"Conversation enregistrée dans {output_file}")
185
+ #
186
+ # return ans
187
+ #
188
+ # print("\n===== MESSAGES BEFORE API CALL =====\n", json.dumps(messages, indent=2))
189
+ # time.sleep(1)
190
+ # response = self.client.agents.complete(
191
+ # agent_id=self.agent_id,
192
+ # messages=messages,
193
+ # tools=self.tools,
194
+ # tool_choice='auto',
195
+ # )
src/ui/sidebar.py CHANGED
@@ -1,10 +1,84 @@
1
  import gradio as gr
 
 
2
 
3
  from src.agent.inference import MistralAgent
4
 
5
- def respond(gr_message, history=None):
6
- agent = MistralAgent()
7
- yield agent.run(gr_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
  def sidebar_ui(state, width=700, visible=True):
 
1
  import gradio as gr
2
+ import json
3
+ from gradio import ChatMessage
4
 
5
  from src.agent.inference import MistralAgent
6
 
7
+ agent = MistralAgent()
8
+
9
+ async def respond(message, history=None):
10
+
11
+ if history is None:
12
+ history = []
13
+ history.append(ChatMessage(role="user", content=message))
14
+
15
+ thinking_msg = ChatMessage(
16
+ role="assistant",
17
+ content="",
18
+ metadata={"title": "Thinking", "status": "pending"}
19
+ )
20
+ history.append(thinking_msg)
21
+ yield history
22
+
23
+ with open("./prompt.md", encoding="utf-8") as f:
24
+ prompt = f.read()
25
+
26
+ messages = [
27
+ {"role": "system", "content": prompt},
28
+ {"role": "user", "content": message},
29
+ #{
30
+ # "role": "assistant",
31
+ # "content": "THINKING:\nLet's tackle this problem",
32
+ ## "prefix": True
33
+ #},
34
+ ]
35
+ payload = {
36
+ "agent_id": agent.agent_id,
37
+ "messages": messages,
38
+ "stream": True,
39
+ "max_tokens": None,
40
+ "tools": agent.tools,
41
+ "tool_choice": "auto",
42
+ "presence_penalty": 0,
43
+ "frequency_penalty": 0,
44
+ "n": 1
45
+ }
46
+
47
+ response = await agent.client.agents.stream_async(**payload)
48
+
49
+ full = ""
50
+ thinking = ""
51
+ final = ""
52
+
53
+ async for chunk in response:
54
+ delta = chunk.data.choices[0].delta
55
+ content = delta.content or ""
56
+ full += content
57
+
58
+ if "FINAL ANSWER:" in full:
59
+ parts = full.split("FINAL ANSWER:", 1)
60
+ thinking = parts[0].replace("THINKING:", "").strip()
61
+ final = parts[1].strip()
62
+ else:
63
+ thinking = full.strip()
64
+ final = ""
65
+
66
+ history[-1] = ChatMessage(
67
+ role="assistant",
68
+ content=thinking,
69
+ metadata={"title": "Thinking", "status": "pending"}
70
+ )
71
+ yield history
72
+
73
+ history[-1] = ChatMessage(
74
+ role="assistant",
75
+ content=thinking,
76
+ metadata={"title": "Thinking", "status": "done"}
77
+ )
78
+
79
+ history.append(ChatMessage(role="assistant", content=final))
80
+ yield history
81
+
82
 
83
 
84
  def sidebar_ui(state, width=700, visible=True):