Update app.py
app.py
CHANGED
@@ -116,13 +116,7 @@ class BasicAgent:
         allowed = {"max_new_tokens", "temperature", "top_k", "top_p"}
         gen_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
 
-        # 2.
-        if stop_sequences is not None:
-            # if your pipeline accepts `stop`, you can do:
-            gen_kwargs["stop"] = stop_sequences
-            # otherwise just ignore it
-
-        # 3. Serialize the message
+        # 2. Serialize the message and get the response
         prompt_str = (
             self._serialize_messages(question)
             if isinstance(question, list)
@@ -131,6 +125,14 @@ class BasicAgent:
         outputs = self.pipe(prompt_str, **gen_kwargs)
         response = outputs[0]["generated_text"]
         # response = self.agent.run(question)
+
+        # 3. Honor SmolAgents' stop_sequences by truncating the generated text
+        if stop_sequences:
+            # cut at the earliest occurrence of any stop sequence
+            cuts = [response.find(s) for s in stop_sequences if response.find(s) != -1]
+            if cuts:
+                response = response[: min(cuts)]
+
         print(f"Agent returning its generated answer: {response}")
 
         # wrap back into a chat message dict
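For context: this commit moves stop-sequence handling from generation time (passing `stop` into `gen_kwargs`, which the transformers text-generation pipeline may not accept) to a post-generation truncation of the decoded text. Below is a minimal standalone sketch of that truncation logic; the helper name `truncate_at_stop_sequences` and the example strings are illustrative and not part of app.py.

from typing import Optional, Sequence


def truncate_at_stop_sequences(text: str, stop_sequences: Optional[Sequence[str]]) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence.

    Mirrors the post-generation truncation added in this commit: the
    pipeline is not asked to stop early; the stop sequences are applied
    to the decoded text afterwards.
    """
    if not stop_sequences:
        return text
    # compute each find() once; keep only sequences that actually occur
    cuts = [i for i in (text.find(s) for s in stop_sequences) if i != -1]
    return text[: min(cuts)] if cuts else text


# Example usage (hypothetical strings, not from the commit):
print(truncate_at_stop_sequences("Answer: 42\nObservation: ...", ["Observation:"]))
# -> "Answer: 42\n"

Computing `find` once per sequence avoids the double scan in the committed list comprehension; the result is identical.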