Spaces:
Running
Running
import gradio as gr | |
from gradio.components import Component | |
from functools import partial | |
from src.webui.webui_manager import WebuiManager | |
from src.utils import config | |
import logging | |
import os | |
from typing import Any, Dict, AsyncGenerator, Optional, Tuple, Union | |
import asyncio | |
import json | |
from src.agent.deep_research.deep_research_agent import DeepResearchAgent | |
from src.utils import llm_provider | |
logger = logging.getLogger(__name__) | |
async def _initialize_llm(provider: Optional[str], model_name: Optional[str], temperature: float,
                          base_url: Optional[str], api_key: Optional[str], num_ctx: Optional[int] = None):
    """Build the LLM described by the given settings.

    Args:
        provider: Provider identifier; an empty/None value disables the LLM.
        model_name: Model identifier; an empty/None value disables the LLM.
        temperature: Sampling temperature forwarded to the provider.
        base_url: Optional endpoint override; empty values are sent as None.
        api_key: Optional API key; empty values are sent as None.
        num_ctx: Context size, forwarded only when the provider is "ollama".

    Returns:
        Whatever ``llm_provider.get_llm_model`` returns, or None when the
        provider/model is missing or when construction raises (the failure is
        logged and shown as a Gradio warning).
    """
    if not (provider and model_name):
        logger.info("LLM Provider or Model Name not specified, LLM will be None.")
        return None

    # Only the "ollama" provider is handed a context size.
    context_size = num_ctx if provider == "ollama" else None

    try:
        logger.info(f"Initializing LLM: Provider={provider}, Model={model_name}, Temp={temperature}")
        return llm_provider.get_llm_model(
            provider=provider,
            model_name=model_name,
            temperature=temperature,
            base_url=base_url or None,
            api_key=api_key or None,
            num_ctx=context_size,
        )
    except Exception as exc:
        logger.error(f"Failed to initialize LLM: {exc}", exc_info=True)
        gr.Warning(
            f"Failed to initialize LLM '{model_name}' for provider '{provider}'. Please check settings. Error: {exc}")
        return None
def _read_file_safe(file_path: str) -> Optional[str]: | |
"""Safely read a file, returning None if it doesn't exist or on error.""" | |
if not os.path.exists(file_path): | |
return None | |
try: | |
with open(file_path, 'r', encoding='utf-8') as f: | |
return f.read() | |
except Exception as e: | |
logger.error(f"Error reading file {file_path}: {e}") | |
return None | |
# --- Deep Research Agent Specific Logic --- | |
async def run_deep_research(webui_manager: WebuiManager, components: Dict[Component, Any]) -> AsyncGenerator[
    Dict[Component, Any], None]:
    """Initialize and run the DeepResearchAgent, streaming UI updates.

    The agent run is started as an asyncio task; while it is running, the
    task's ``research_plan.md`` file is polled and its content is pushed to
    the markdown display. When the run finishes, ``report.md`` (or the
    ``report`` entry of the agent result) is shown, and the controls are
    reset in all cases.

    Args:
        webui_manager: Manager that resolves component IDs and stores the
            agent, the running asyncio task, the task ID and the save dir.
        components: Mapping from Gradio component to its current value, as
            delivered by the click event.

    Yields:
        Dicts mapping components to their updates.
    """
    # --- Get Components ---
    research_task_comp = webui_manager.get_component_by_id("deep_research_agent.research_task")
    resume_task_id_comp = webui_manager.get_component_by_id("deep_research_agent.resume_task_id")
    parallel_num_comp = webui_manager.get_component_by_id("deep_research_agent.parallel_num")
    # Note: the component registered as "max_query" is the "Research Save Dir" textbox.
    save_dir_comp = webui_manager.get_component_by_id("deep_research_agent.max_query")
    start_button_comp = webui_manager.get_component_by_id("deep_research_agent.start_button")
    stop_button_comp = webui_manager.get_component_by_id("deep_research_agent.stop_button")
    markdown_display_comp = webui_manager.get_component_by_id("deep_research_agent.markdown_display")
    markdown_download_comp = webui_manager.get_component_by_id("deep_research_agent.markdown_download")
    mcp_server_config_comp = webui_manager.get_component_by_id("deep_research_agent.mcp_server_config")

    # --- 1. Get Task and Settings ---
    # `or ""` guards against a component whose value is None.
    task_topic = (components.get(research_task_comp) or "").strip()
    task_id_to_resume = (components.get(resume_task_id_comp) or "").strip() or None
    max_parallel_agents = int(components.get(parallel_num_comp, 1))
    # An empty save-dir field falls back to the default instead of breaking makedirs.
    base_save_dir = components.get(save_dir_comp) or "./tmp/deep_research"
    mcp_server_config_str = components.get(mcp_server_config_comp)

    if not task_topic:
        gr.Warning("Please enter a research task.")
        yield {start_button_comp: gr.update(interactive=True)}  # Re-enable start button
        return

    # The MCP config comes from an editable textbox, so it may not be valid JSON.
    try:
        mcp_config = json.loads(mcp_server_config_str) if mcp_server_config_str else None
    except (ValueError, TypeError) as e:
        logger.error(f"Invalid MCP server config: {e}")
        gr.Warning(f"MCP server config is not valid JSON: {e}")
        yield {start_button_comp: gr.update(interactive=True)}  # Re-enable start button
        return

    # Store base save dir for stop handler
    webui_manager.dr_save_dir = base_save_dir
    os.makedirs(base_save_dir, exist_ok=True)

    # --- 2. Initial UI Update ---
    yield {
        start_button_comp: gr.update(value="⏳ Running...", interactive=False),
        stop_button_comp: gr.update(interactive=True),
        research_task_comp: gr.update(interactive=False),
        resume_task_id_comp: gr.update(interactive=False),
        parallel_num_comp: gr.update(interactive=False),
        save_dir_comp: gr.update(interactive=False),
        markdown_display_comp: gr.update(value="Starting research..."),
        markdown_download_comp: gr.update(value=None, interactive=False)
    }

    agent_task = None
    running_task_id = None
    plan_file_path = None
    report_file_path = None
    last_plan_content = None
    last_plan_mtime = 0

    try:
        # --- 3. Get LLM and Browser Config from other tabs ---
        # Access settings values via components dict, getting IDs from webui_manager
        def get_setting(tab: str, key: str, default: Any = None):
            comp = webui_manager.id_to_component.get(f"{tab}.{key}")
            return components.get(comp, default) if comp else default

        # LLM Config (from agent_settings tab)
        llm_provider_name = get_setting("agent_settings", "llm_provider")
        llm_model_name = get_setting("agent_settings", "llm_model_name")
        # The temperature used here is never lower than 0.5.
        llm_temperature = max(get_setting("agent_settings", "llm_temperature", 0.5), 0.5)
        llm_base_url = get_setting("agent_settings", "llm_base_url")
        llm_api_key = get_setting("agent_settings", "llm_api_key")
        ollama_num_ctx = get_setting("agent_settings", "ollama_num_ctx")

        llm = await _initialize_llm(
            llm_provider_name, llm_model_name, llm_temperature, llm_base_url, llm_api_key,
            ollama_num_ctx if llm_provider_name == "ollama" else None
        )
        if not llm:
            raise ValueError("LLM Initialization failed. Please check Agent Settings.")

        # Browser Config (from browser_settings tab)
        # Note: DeepResearchAgent constructor takes a dict, not full Browser/Context objects
        browser_config_dict = {
            "headless": get_setting("browser_settings", "headless", False),
            "disable_security": get_setting("browser_settings", "disable_security", False),
            "browser_binary_path": get_setting("browser_settings", "browser_binary_path"),
            "user_data_dir": get_setting("browser_settings", "browser_user_data_dir"),
            "window_width": int(get_setting("browser_settings", "window_w", 1280)),
            "window_height": int(get_setting("browser_settings", "window_h", 1100)),
            # Add other relevant fields if DeepResearchAgent accepts them
        }

        # --- 4. Initialize or Get Agent ---
        # An agent already stored on the manager is reused as-is.
        if not webui_manager.dr_agent:
            webui_manager.dr_agent = DeepResearchAgent(
                llm=llm,
                browser_config=browser_config_dict,
                mcp_server_config=mcp_config
            )
            logger.info("DeepResearchAgent initialized.")

        # --- 5. Start Agent Run ---
        agent_run_coro = webui_manager.dr_agent.run(
            topic=task_topic,
            task_id=task_id_to_resume,
            save_dir=base_save_dir,
            max_parallel_browsers=max_parallel_agents
        )
        agent_task = asyncio.create_task(agent_run_coro)
        webui_manager.dr_current_task = agent_task

        # Wait briefly for the agent to start and potentially create the task ID/folder
        await asyncio.sleep(1.0)

        # Determine the actual task ID being used (agent sets this)
        running_task_id = webui_manager.dr_agent.current_task_id
        if not running_task_id:
            # Agent has not published an ID yet; fall back to the resume ID if one was given.
            running_task_id = task_id_to_resume
            if not running_task_id:
                logger.warning("Could not determine running task ID immediately.")
                # We can still monitor, but might miss initial plan if ID needed for path
            else:
                logger.info(f"Assuming task ID based on resume ID: {running_task_id}")
        else:
            logger.info(f"Agent started with Task ID: {running_task_id}")

        webui_manager.dr_task_id = running_task_id  # Store for stop handler

        # --- 6. Monitor Progress via research_plan.md ---
        if running_task_id:
            task_specific_dir = os.path.join(base_save_dir, str(running_task_id))
            plan_file_path = os.path.join(task_specific_dir, "research_plan.md")
            report_file_path = os.path.join(task_specific_dir, "report.md")
            logger.info(f"Monitoring plan file: {plan_file_path}")
        else:
            logger.warning("Cannot monitor plan file: Task ID unknown.")

        while not agent_task.done():
            update_dict = {resume_task_id_comp: gr.update(value=running_task_id)}

            if getattr(webui_manager.dr_agent, 'stopped', False):
                logger.info("Stop signal detected from agent state.")
                break  # Exit monitoring loop

            # Check and update research plan display
            if plan_file_path:
                try:
                    current_mtime = os.path.getmtime(plan_file_path) if os.path.exists(plan_file_path) else 0
                    if current_mtime > last_plan_mtime:
                        logger.info(f"Detected change in {plan_file_path}")
                        plan_content = _read_file_safe(plan_file_path)
                        if plan_content is None:
                            # File might have been deleted or became unreadable:
                            # reset so the next pass retries, and never push
                            # None into the markdown display.
                            last_plan_mtime = 0
                        else:
                            if plan_content != last_plan_content:
                                update_dict[markdown_display_comp] = gr.update(value=plan_content)
                                last_plan_content = plan_content
                            # Record the mtime even when the content is unchanged
                            # so the same file is not re-read on every pass.
                            last_plan_mtime = current_mtime
                except Exception as e:
                    logger.warning(f"Error checking/reading plan file {plan_file_path}: {e}")
                    # Avoid continuous logging for the same error
                    await asyncio.sleep(2.0)

            yield update_dict

            await asyncio.sleep(1.0)  # Check file changes every second

        # --- 7. Task Finalization ---
        logger.info("Agent task processing finished. Awaiting final result...")
        final_result_dict = await agent_task  # Get result or raise exception
        logger.info(f"Agent run completed. Result keys: {final_result_dict.keys() if final_result_dict else 'None'}")

        # Try to get task ID from result if not known before
        if not running_task_id and final_result_dict and 'task_id' in final_result_dict:
            running_task_id = final_result_dict['task_id']
            webui_manager.dr_task_id = running_task_id
            task_specific_dir = os.path.join(base_save_dir, str(running_task_id))
            report_file_path = os.path.join(task_specific_dir, "report.md")
            logger.info(f"Task ID confirmed from result: {running_task_id}")

        final_ui_update = {}
        if report_file_path and os.path.exists(report_file_path):
            logger.info(f"Loading final report from: {report_file_path}")
            report_content = _read_file_safe(report_file_path)
            if report_content:
                final_ui_update[markdown_display_comp] = gr.update(value=report_content)
                final_ui_update[markdown_download_comp] = gr.File(value=report_file_path,
                                                                  label=f"Report ({running_task_id}.md)",
                                                                  interactive=True)
            else:
                final_ui_update[markdown_display_comp] = gr.update(
                    value="# Research Complete\n\n*Error reading final report file.*")
        elif final_result_dict and 'report' in final_result_dict:
            logger.info("Using report content directly from agent result.")
            # If agent directly returns report content
            final_ui_update[markdown_display_comp] = gr.update(value=final_result_dict['report'])
            # Cannot offer download if only content is available
            final_ui_update[markdown_download_comp] = gr.update(value=None, label="Download Research Report",
                                                                interactive=False)
        else:
            logger.warning("Final report file not found and not in result dict.")
            final_ui_update[markdown_display_comp] = gr.update(value="# Research Complete\n\n*Final report not found.*")

        yield final_ui_update

    except Exception as e:
        logger.error(f"Error during Deep Research Agent execution: {e}", exc_info=True)
        # gr.Error is an exception class and shows nothing unless raised; raising
        # here would skip the failure message below, so use a warning toast.
        gr.Warning(f"Research failed: {e}")
        yield {markdown_display_comp: gr.update(value=f"# Research Failed\n\n**Error:**\n```\n{e}\n```")}

    finally:
        # --- 8. Final UI Reset ---
        webui_manager.dr_current_task = None  # Clear task reference
        webui_manager.dr_task_id = None  # Clear running task ID

        yield {
            start_button_comp: gr.update(value="▶️ Run", interactive=True),
            stop_button_comp: gr.update(interactive=False),
            research_task_comp: gr.update(interactive=True),
            resume_task_id_comp: gr.update(value="", interactive=True),
            parallel_num_comp: gr.update(interactive=True),
            save_dir_comp: gr.update(interactive=True),
            # Keep download button enabled if file exists
            markdown_download_comp: gr.update() if report_file_path and os.path.exists(report_file_path) else gr.update(
                interactive=False)
        }
async def stop_deep_research(webui_manager: WebuiManager) -> Dict[Component, Any]:
    """Handle the Stop button click for the deep research tab.

    When a research task is running, the agent is asked to stop and the
    report written so far (if any) is displayed; the final control reset is
    left to ``run_deep_research``. When nothing is running, the tab's
    controls are simply re-enabled.

    Args:
        webui_manager: Manager holding the agent, the running asyncio task,
            the task ID and the base save directory.

    Returns:
        Dict mapping components to their updates.
    """
    logger.info("Stop button clicked for Deep Research.")
    agent = webui_manager.dr_agent
    task = webui_manager.dr_current_task
    task_id = webui_manager.dr_task_id
    base_save_dir = webui_manager.dr_save_dir

    stop_button_comp = webui_manager.get_component_by_id("deep_research_agent.stop_button")
    start_button_comp = webui_manager.get_component_by_id("deep_research_agent.start_button")
    markdown_display_comp = webui_manager.get_component_by_id("deep_research_agent.markdown_display")
    markdown_download_comp = webui_manager.get_component_by_id("deep_research_agent.markdown_download")

    final_update = {
        stop_button_comp: gr.update(interactive=False, value="⏹️ Stopping...")
    }

    if agent and task and not task.done():
        logger.info("Signalling DeepResearchAgent to stop.")
        try:
            # stop() is awaited; a failure is logged and the handler continues.
            await agent.stop()
        except Exception as e:
            logger.error(f"Error calling agent.stop(): {e}")

        # The run_deep_research loop should detect the stop and exit.
        # This handler returns an intermediate "Stopping..." state; the final
        # reset is done by run_deep_research.

        # Try to show the final report if available after stopping
        await asyncio.sleep(1.5)  # Give agent a moment to write final files potentially
        report_file_path = None
        if task_id and base_save_dir:
            report_file_path = os.path.join(base_save_dir, str(task_id), "report.md")

        if report_file_path and os.path.exists(report_file_path):
            report_content = _read_file_safe(report_file_path)
            if report_content:
                final_update[markdown_display_comp] = gr.update(
                    value=report_content + "\n\n---\n*Research stopped by user.*")
                final_update[markdown_download_comp] = gr.File(value=report_file_path, label=f"Report ({task_id}.md)",
                                                               interactive=True)
            else:
                final_update[markdown_display_comp] = gr.update(
                    value="# Research Stopped\n\n*Error reading final report file after stop.*")
        else:
            final_update[markdown_display_comp] = gr.update(value="# Research Stopped by User")

        # Keep start button disabled, run_deep_research finally block will re-enable it.
        final_update[start_button_comp] = gr.update(interactive=False)

    else:
        logger.warning("Stop clicked but no active research task found.")
        # Reset UI state just in case. Only IDs registered by the tab are
        # looked up: the parallel-agent field is "parallel_num" (there is no
        # "max_iteration" component in this tab).
        final_update = {
            start_button_comp: gr.update(interactive=True),
            stop_button_comp: gr.update(interactive=False),
            webui_manager.get_component_by_id("deep_research_agent.research_task"): gr.update(interactive=True),
            webui_manager.get_component_by_id("deep_research_agent.resume_task_id"): gr.update(interactive=True),
            webui_manager.get_component_by_id("deep_research_agent.parallel_num"): gr.update(interactive=True),
            webui_manager.get_component_by_id("deep_research_agent.max_query"): gr.update(interactive=True),
        }

    return final_update
async def update_mcp_server(mcp_file: str, webui_manager: WebuiManager):
    """Load a new MCP server JSON file and return the textbox updates for it.

    If the manager already holds a deep research agent, its MCP client is
    closed first.

    Args:
        mcp_file: Path of the selected MCP JSON file (may be empty/None).
        webui_manager: Manager that may hold the current ``dr_agent``.

    Returns:
        A ``(text, update)`` tuple: the pretty-printed JSON and a
        visibility-on update, or ``(None, hidden update)`` when the path is
        missing, is not a ``.json`` file, or cannot be read/parsed.
    """
    agent = getattr(webui_manager, "dr_agent", None)
    if agent:
        logger.warning("⚠️ Close controller because mcp file has changed!")
        await agent.close_mcp_client()

    if not mcp_file or not os.path.exists(mcp_file) or not mcp_file.endswith('.json'):
        logger.warning(f"{mcp_file} is not a valid MCP file.")
        return None, gr.update(visible=False)

    try:
        # JSON text is read as UTF-8 rather than the platform default encoding.
        with open(mcp_file, 'r', encoding='utf-8') as f:
            mcp_server = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers both malformed JSON and undecodable bytes; treat
        # the file like any other invalid MCP file instead of raising.
        logger.warning(f"{mcp_file} is not a valid MCP file: {e}")
        return None, gr.update(visible=False)

    return json.dumps(mcp_server, indent=2), gr.update(visible=True)
def create_deep_research_agent_tab(webui_manager: WebuiManager):
    """Create the deep research agent tab and wire its event handlers.

    Builds the tab's components, registers them with the manager under the
    "deep_research_agent" prefix, and connects the MCP file upload, Run and
    Stop handlers.

    Args:
        webui_manager: Manager that stores the components and agent state.
    """
    # NOTE(review): the container nesting below was reconstructed from a source
    # whose indentation was lost — confirm the layout against the rendered UI.
    tab_components = {}

    with gr.Group():
        with gr.Row():
            mcp_json_file = gr.File(label="MCP server json", interactive=True, file_types=[".json"])
            mcp_server_config = gr.Textbox(label="MCP server", lines=6, interactive=True, visible=False)

    with gr.Group():
        research_task = gr.Textbox(label="Research Task", lines=5,
                                   value="Give me a detailed travel plan to Switzerland from June 1st to 10th.",
                                   interactive=True)
        with gr.Row():
            resume_task_id = gr.Textbox(label="Resume Task ID", value="",
                                        interactive=True)
            parallel_num = gr.Number(label="Parallel Agent Num", value=1,
                                     precision=0,
                                     interactive=True)
            # Registered under the ID "max_query", but this is the save-directory field.
            max_query = gr.Textbox(label="Research Save Dir", value="./tmp/deep_research",
                                   interactive=True)
    with gr.Row():
        stop_button = gr.Button("⏹️ Stop", variant="stop", scale=2)
        start_button = gr.Button("▶️ Run", variant="primary", scale=3)
    with gr.Group():
        markdown_display = gr.Markdown(label="Research Report")
        markdown_download = gr.File(label="Download Research Report", interactive=False)

    tab_components.update(
        dict(
            research_task=research_task,
            parallel_num=parallel_num,
            max_query=max_query,
            start_button=start_button,
            stop_button=stop_button,
            markdown_display=markdown_display,
            markdown_download=markdown_download,
            resume_task_id=resume_task_id,
            mcp_json_file=mcp_json_file,
            mcp_server_config=mcp_server_config,
        )
    )
    webui_manager.add_components("deep_research_agent", tab_components)
    webui_manager.init_deep_research_agent()

    async def update_wrapper(mcp_file):
        """Forward an MCP file change to update_mcp_server and emit its result."""
        update_dict = await update_mcp_server(mcp_file, webui_manager)
        yield update_dict

    # The textbox is listed twice: update_mcp_server returns its new text and
    # a separate visibility update.
    mcp_json_file.change(
        update_wrapper,
        inputs=[mcp_json_file],
        outputs=[mcp_server_config, mcp_server_config]
    )

    dr_tab_outputs = list(tab_components.values())
    # Collected after add_components, so this tab's own components are included.
    all_managed_inputs = set(webui_manager.get_components())

    # --- Define Event Handler Wrappers ---
    async def start_wrapper(comps: Dict[Component, Any]) -> AsyncGenerator[Dict[Component, Any], None]:
        """Stream every update produced by run_deep_research."""
        async for update in run_deep_research(webui_manager, comps):
            yield update

    async def stop_wrapper() -> AsyncGenerator[Dict[Component, Any], None]:
        """Emit the single update produced by stop_deep_research."""
        update_dict = await stop_deep_research(webui_manager)
        yield update_dict

    # --- Connect Handlers ---
    start_button.click(
        fn=start_wrapper,
        inputs=all_managed_inputs,
        outputs=dr_tab_outputs
    )

    stop_button.click(
        fn=stop_wrapper,
        inputs=None,
        outputs=dr_tab_outputs
    )