Update app.py
Browse files
app.py
CHANGED
@@ -110,11 +110,19 @@ class BasicAgent:
|
|
110 |
prompt.append(f"{role}: {text}")
|
111 |
return "\n".join(prompt)
|
112 |
|
113 |
-
def generate(self, question: str) -> str:
|
114 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
115 |
-
|
116 |
allowed = {"max_new_tokens", "temperature", "top_k", "top_p"}
|
117 |
gen_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
prompt_str = (
|
119 |
self._serialize_messages(prompt)
|
120 |
if isinstance(prompt, list)
|
|
|
110 |
prompt.append(f"{role}: {text}")
|
111 |
return "\n".join(prompt)
|
112 |
|
113 |
+
def generate(self, question: str, stop_sequences=None, **kwargs) -> str:
|
114 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
115 |
+
# 1. Build the HF kwargs
|
116 |
allowed = {"max_new_tokens", "temperature", "top_k", "top_p"}
|
117 |
gen_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
|
118 |
+
|
119 |
+
# 2. Optionally map SmolAgents’ stop_sequences → HF pipeline’s 'stop'
|
120 |
+
if stop_sequences is not None:
|
121 |
+
# if your pipeline accepts `stop`, you can do:
|
122 |
+
gen_kwargs["stop"] = stop_sequences
|
123 |
+
# otherwise just ignore it
|
124 |
+
|
125 |
+
# 3. Serialize the message
|
126 |
prompt_str = (
|
127 |
self._serialize_messages(prompt)
|
128 |
if isinstance(prompt, list)
|