Spaces:
Sleeping
Sleeping
mriusero
committed on
Commit
·
ab7f293
1
Parent(s):
a9effa1
feat: thinking + streaming
Browse files- src/agent/inference.py +157 -124
- src/ui/sidebar.py +77 -3
src/agent/inference.py
CHANGED
@@ -32,131 +32,164 @@ class MistralAgent:
|
|
32 |
]
|
33 |
).get('tools')
|
34 |
|
35 |
-
def make_initial_request(self, input):
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
if
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
results.append((tool_call.id, function_name, None))
|
101 |
|
102 |
-
for tool_call_id, function_name, function_result in results:
|
103 |
-
messages.append({
|
104 |
-
"role": "assistant",
|
105 |
-
"tool_calls": [
|
106 |
-
{
|
107 |
-
"id": tool_call_id,
|
108 |
-
"type": "function",
|
109 |
-
"function": {
|
110 |
-
"name": function_name,
|
111 |
-
"arguments": json.dumps(function_params),
|
112 |
-
}
|
113 |
-
}
|
114 |
-
]
|
115 |
-
})
|
116 |
-
messages.append(
|
117 |
-
{
|
118 |
-
"role": "tool",
|
119 |
-
"content": function_result if function_result is not None else f"Error occurred: {function_name} failed to execute",
|
120 |
-
"tool_call_id": tool_call_id,
|
121 |
-
},
|
122 |
-
)
|
123 |
-
for message in messages:
|
124 |
-
if "prefix" in message:
|
125 |
-
del message["prefix"]
|
126 |
-
messages.append(
|
127 |
-
{
|
128 |
-
"role": "assistant",
|
129 |
-
"content": f"Based on the results, ",
|
130 |
-
"prefix": True,
|
131 |
-
}
|
132 |
-
)
|
133 |
-
else:
|
134 |
-
for message in messages:
|
135 |
-
if "prefix" in message:
|
136 |
-
del message["prefix"]
|
137 |
-
messages.append(
|
138 |
-
{
|
139 |
-
"role": "assistant",
|
140 |
-
"content": choice.message.content,
|
141 |
-
}
|
142 |
-
)
|
143 |
-
if 'FINAL ANSWER:' in choice.message.content:
|
144 |
-
print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
|
145 |
-
ans = choice.message.content.split('FINAL ANSWER:')[1].strip()
|
146 |
|
147 |
-
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
148 |
-
output_file = f"chat_{timestamp}.json"
|
149 |
-
with open(output_file, "w", encoding="utf-8") as f:
|
150 |
-
json.dump(messages, f, indent=2, ensure_ascii=False)
|
151 |
-
print(f"Conversation enregistrée dans {output_file}")
|
152 |
|
153 |
-
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
]
|
33 |
).get('tools')
|
34 |
|
35 |
+
#def make_initial_request(self, input):
|
36 |
+
# """Make the initial request to the agent with the given input."""
|
37 |
+
# with open("./prompt.md", 'r', encoding='utf-8') as file:
|
38 |
+
# self.prompt = file.read()
|
39 |
+
# messages = [
|
40 |
+
# {"role": "system", "content": self.prompt},
|
41 |
+
# {"role": "user", "content": input},
|
42 |
+
# {
|
43 |
+
# "role": "assistant",
|
44 |
+
# "content": "THINKING:\nLet's tackle this problem, ",
|
45 |
+
# "prefix": True,
|
46 |
+
# },
|
47 |
+
# ]
|
48 |
+
# payload = {
|
49 |
+
# "agent_id": self.agent_id,
|
50 |
+
# "messages": messages,
|
51 |
+
# "max_tokens": None,
|
52 |
+
# "stream": True,
|
53 |
+
# "stop": None,
|
54 |
+
# "random_seed": None,
|
55 |
+
# "response_format": None,
|
56 |
+
# "tools": self.tools,
|
57 |
+
# "tool_choice": 'auto',
|
58 |
+
# "presence_penalty": 0,
|
59 |
+
# "frequency_penalty": 0,
|
60 |
+
# "n": 1,
|
61 |
+
# "prediction": None,
|
62 |
+
# "parallel_tool_calls": None
|
63 |
+
# }
|
64 |
+
# stream = self.client.agents.complete(**payload)
|
65 |
+
# return stream, messages
|
66 |
+
#
|
67 |
+
#def run(self, input):
|
68 |
+
# """Run the agent with the given input and process the response."""
|
69 |
+
# print("\n===== Asking the agent =====\n")
|
70 |
+
# stream, messages = self.make_initial_request(input)
|
71 |
+
#
|
72 |
+
# for data in stream:
|
73 |
+
# # Si `stream` renvoie des chaînes brutes de type `data: {...}`
|
74 |
+
# if isinstance(data, str) and data.startswith("data: "):
|
75 |
+
# try:
|
76 |
+
# json_str = data[len("data: "):].strip()
|
77 |
+
# if json_str == "[DONE]":
|
78 |
+
# break
|
79 |
+
# chunk = json.loads(json_str)
|
80 |
+
# delta = chunk.get("choices", [{}])[0].get("delta", {})
|
81 |
+
# content = delta.get("content")
|
82 |
+
# if content:
|
83 |
+
# yield content
|
84 |
+
#
|
85 |
+
# # Fin de réponse
|
86 |
+
# if chunk["choices"][0].get("finish_reason") is not None:
|
87 |
+
# break
|
88 |
+
# except json.JSONDecodeError:
|
89 |
+
# continue
|
90 |
+
#
|
91 |
+
# # Si `stream` donne directement des dicts (selon ton client)
|
92 |
+
# elif isinstance(data, dict):
|
93 |
+
# delta = data.get("choices", [{}])[0].get("delta", {})
|
94 |
+
# content = delta.get("content")
|
95 |
+
# if content:
|
96 |
+
# yield content
|
97 |
+
#
|
98 |
+
# if data["choices"][0].get("finish_reason") is not None:
|
99 |
+
# break
|
|
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
+
#first_iteration = True
|
104 |
|
105 |
+
#while True:
|
106 |
+
# time.sleep(1)
|
107 |
+
# if hasattr(response, 'choices') and response.choices:
|
108 |
+
# choice = response.choices[0]
|
109 |
+
#
|
110 |
+
# if first_iteration:
|
111 |
+
# messages = [message for message in messages if not message.get("prefix")]
|
112 |
+
# messages.append(
|
113 |
+
# {
|
114 |
+
# "role": "assistant",
|
115 |
+
# "content": choice.message.content,
|
116 |
+
# "prefix": True,
|
117 |
+
# },
|
118 |
+
# )
|
119 |
+
# first_iteration = False
|
120 |
+
# else:
|
121 |
+
# if choice.message.tool_calls:
|
122 |
+
# results = []
|
123 |
+
#
|
124 |
+
# for tool_call in choice.message.tool_calls:
|
125 |
+
# function_name = tool_call.function.name
|
126 |
+
# function_params = json.loads(tool_call.function.arguments)
|
127 |
+
#
|
128 |
+
# try:
|
129 |
+
# function_result = self.names_to_functions[function_name](**function_params)
|
130 |
+
# results.append((tool_call.id, function_name, function_result))
|
131 |
+
#
|
132 |
+
# except Exception as e:
|
133 |
+
# results.append((tool_call.id, function_name, None))
|
134 |
+
#
|
135 |
+
# for tool_call_id, function_name, function_result in results:
|
136 |
+
# messages.append({
|
137 |
+
# "role": "assistant",
|
138 |
+
# "tool_calls": [
|
139 |
+
# {
|
140 |
+
# "id": tool_call_id,
|
141 |
+
# "type": "function",
|
142 |
+
# "function": {
|
143 |
+
# "name": function_name,
|
144 |
+
# "arguments": json.dumps(function_params),
|
145 |
+
# }
|
146 |
+
# }
|
147 |
+
# ]
|
148 |
+
# })
|
149 |
+
# messages.append(
|
150 |
+
# {
|
151 |
+
# "role": "tool",
|
152 |
+
# "content": function_result if function_result is not None else f"Error occurred: {function_name} failed to execute",
|
153 |
+
# "tool_call_id": tool_call_id,
|
154 |
+
# },
|
155 |
+
# )
|
156 |
+
# for message in messages:
|
157 |
+
# if "prefix" in message:
|
158 |
+
# del message["prefix"]
|
159 |
+
# messages.append(
|
160 |
+
# {
|
161 |
+
# "role": "assistant",
|
162 |
+
# "content": f"Based on the results, ",
|
163 |
+
# "prefix": True,
|
164 |
+
# }
|
165 |
+
# )
|
166 |
+
# else:
|
167 |
+
# for message in messages:
|
168 |
+
# if "prefix" in message:
|
169 |
+
# del message["prefix"]
|
170 |
+
# messages.append(
|
171 |
+
# {
|
172 |
+
# "role": "assistant",
|
173 |
+
# "content": choice.message.content,
|
174 |
+
# }
|
175 |
+
# )
|
176 |
+
# if 'FINAL ANSWER:' in choice.message.content:
|
177 |
+
# print("\n===== END OF REQUEST =====\n", json.dumps(messages, indent=2))
|
178 |
+
# ans = choice.message.content.split('FINAL ANSWER:')[1].strip()
|
179 |
+
#
|
180 |
+
# timestamp = time.strftime("%Y%m%d-%H%M%S")
|
181 |
+
# output_file = f"chat_{timestamp}.json"
|
182 |
+
# with open(output_file, "w", encoding="utf-8") as f:
|
183 |
+
# json.dump(messages, f, indent=2, ensure_ascii=False)
|
184 |
+
# print(f"Conversation enregistrée dans {output_file}")
|
185 |
+
#
|
186 |
+
# return ans
|
187 |
+
#
|
188 |
+
# print("\n===== MESSAGES BEFORE API CALL =====\n", json.dumps(messages, indent=2))
|
189 |
+
# time.sleep(1)
|
190 |
+
# response = self.client.agents.complete(
|
191 |
+
# agent_id=self.agent_id,
|
192 |
+
# messages=messages,
|
193 |
+
# tools=self.tools,
|
194 |
+
# tool_choice='auto',
|
195 |
+
# )
|
src/ui/sidebar.py
CHANGED
@@ -1,10 +1,84 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
|
3 |
from src.agent.inference import MistralAgent
|
4 |
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
|
10 |
def sidebar_ui(state, width=700, visible=True):
|
|
|
1 |
import gradio as gr
|
2 |
+
import json
|
3 |
+
from gradio import ChatMessage
|
4 |
|
5 |
from src.agent.inference import MistralAgent
|
6 |
|
7 |
+
# Module-level agent instance; respond() reads its agent_id, tools and client.
agent = MistralAgent()
|
8 |
+
|
9 |
+
def _split_reply(text):
    """Split accumulated model output into ``(thinking, final_answer)``.

    Everything before the first ``FINAL ANSWER:`` marker is thinking text and
    everything after it is the final answer. The ``THINKING:`` marker is
    removed whether or not a final answer has arrived yet, so the streamed
    thinking text never shows the raw marker.
    """
    thinking, _, final = text.partition("FINAL ANSWER:")
    return thinking.replace("THINKING:", "").strip(), final.strip()


def _thinking_message(content, status):
    """Build the assistant "Thinking" chat message with the given status."""
    return ChatMessage(
        role="assistant",
        content=content,
        metadata={"title": "Thinking", "status": status},
    )


async def respond(message, history=None):
    """Stream the agent's reply to ``message`` into a chat history.

    This is an async generator: it yields the updated ``history`` list once
    after the placeholder "Thinking" message is added, once per streamed
    chunk, and once more when the stream ends.

    Args:
        message: The user's text; sent to the agent as the user message.
        history: Existing list of chat messages, appended to in place. A new
            list is created when ``None``.

    Yields:
        The ``history`` list, whose last entries are the "Thinking" message
        and, once the stream has ended with a non-empty final answer, a plain
        assistant message holding that answer.
    """
    if history is None:
        history = []
    history.append(ChatMessage(role="user", content=message))

    # Show a pending "Thinking" bubble immediately, before any network call.
    history.append(_thinking_message("", "pending"))
    yield history

    with open("./prompt.md", encoding="utf-8") as f:
        prompt = f.read()

    # NOTE: an assistant prefix message ("THINKING:\nLet's tackle this
    # problem") used to be appended here and is currently disabled.
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": message},
    ]
    payload = {
        "agent_id": agent.agent_id,
        "messages": messages,
        "stream": True,
        "max_tokens": None,
        "tools": agent.tools,
        "tool_choice": "auto",
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "n": 1,
    }

    response = await agent.client.agents.stream_async(**payload)

    full = ""
    thinking = ""
    final = ""

    async for chunk in response:
        delta = chunk.data.choices[0].delta
        full += delta.content or ""

        # Re-split the whole accumulated text on every chunk, because a
        # marker may arrive split across several chunks.
        thinking, final = _split_reply(full)

        history[-1] = _thinking_message(thinking, "pending")
        yield history

    history[-1] = _thinking_message(thinking, "done")

    # Only add an answer bubble when the model actually produced a final
    # answer; otherwise an empty assistant message would be shown.
    if final:
        history.append(ChatMessage(role="assistant", content=final))
    yield history
|
81 |
+
|
82 |
|
83 |
|
84 |
def sidebar_ui(state, width=700, visible=True):
|