|
import json |
|
import re |
|
from vllm.entrypoints.openai.protocol import ( |
|
ExtractedToolCallInformation, |
|
FunctionCall, |
|
ToolCall, |
|
) |
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
|
ToolParser, |
|
ToolParserManager, |
|
) |
|
from vllm.logger import init_logger |
|
|
|
logger = init_logger(__name__) |
|
|
|
|
|
@ToolParserManager.register_module("mistral_v3_debug")
class MistralV3DebugToolParser(ToolParser):
    """Custom parser for Mistral v3 with detailed logging.

    Ensures OpenAI-compatible tool calls while debugging missing arguments.
    Registered with ``ToolParserManager`` under the name
    ``mistral_v3_debug``.
    """

    # The tool-call token, optional whitespace, and (as a lookahead only) the
    # opening bracket/brace of the JSON payload that must follow it.
    _PAYLOAD_START = re.compile(r"\[TOOL_CALLS\]\s*(?=[\[{])")
    # A comma between two consecutive top-level JSON values in the payload.
    _NEXT_VALUE = re.compile(r"\s*,\s*(?=[\[{])")

    def extract_tool_calls(
        self, model_output: str, request
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from model output using Mistral's special token.

        The JSON that follows ``[TOOL_CALLS]`` may be a list of call objects
        or a comma-separated sequence of call objects without the enclosing
        square brackets. The payload is decoded with a real JSON decoder, so
        arguments that themselves contain arrays (or strings containing
        ``]``) are handled, and any text after the JSON is ignored.

        Args:
            model_output: Raw text generated by the model.
            request: The originating request; only logged here.

        Returns:
            ``ExtractedToolCallInformation`` with ``tools_called=True`` and
            the parsed calls when at least one usable call was found.
            Otherwise ``tools_called=False`` with ``model_output`` passed
            through unchanged as ``content``. Parsing errors are logged and
            never propagate to the caller.
        """
        logger.info(
            "Extracting tool calls from model output... | %r", request
        )
        logger.info("Raw model output: %s", model_output)

        try:
            start = self._PAYLOAD_START.search(model_output)
            if start is None:
                logger.warning(
                    "No valid [TOOL_CALLS] block found. "
                    "Treating as normal content."
                )
                return self._as_content(model_output)

            payload = model_output[start.end():]
            logger.debug("Extracted JSON snippet: %s", payload)

            tool_call_data = self._decode_payload(payload)
            logger.debug("Decoded tool call data: %s", tool_call_data)

            tool_calls = []
            for i, tool_item in enumerate(tool_call_data):
                logger.debug("Processing item %d: %s", i, tool_item)

                if not isinstance(tool_item, dict):
                    logger.error(
                        "Item %d is not a JSON object. Skipping.", i
                    )
                    continue

                name = tool_item.get("name", "unknown_tool")
                args = self._normalize_arguments(
                    name, tool_item.get("arguments", {})
                )

                # The OpenAI protocol carries arguments as a JSON string.
                arguments_json = json.dumps(args, ensure_ascii=False)
                logger.debug(
                    "Parsed arguments for '%s': %s", name, arguments_json
                )

                tool_calls.append(
                    ToolCall(
                        type="function",
                        id=f"call_{i}",
                        function=FunctionCall(
                            name=name, arguments=arguments_json
                        ),
                    )
                )

            if not tool_calls:
                # Reporting tools_called=True with nothing to call would
                # drop the model output entirely; hand it back instead.
                logger.error(
                    "No usable tool call in [TOOL_CALLS] block. "
                    "Returning as content."
                )
                return self._as_content(model_output)

            logger.info(
                "Successfully extracted %d tool call(s).", len(tool_calls)
            )
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=None
            )

        except json.JSONDecodeError as e:
            logger.error("Failed to parse tool calls JSON: %s", e)
            return self._as_content(model_output)

        except Exception:
            # Boundary handler: a parser failure must not break the request.
            logger.exception("Unexpected error while parsing tool calls.")
            return self._as_content(model_output)

    @staticmethod
    def _as_content(model_output: str) -> ExtractedToolCallInformation:
        """Wrap ``model_output`` as ordinary content with no tool calls."""
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    @classmethod
    def _decode_payload(cls, payload: str) -> list:
        """Decode the JSON after the tool-call token into a flat item list.

        ``payload`` must start with ``[`` or ``{``. Consecutive top-level
        values separated by commas are all decoded; nested lists are
        flattened so that ``[{...}]``, ``[[{...}]]`` and ``{...}, {...}``
        yield the same items. Text following the last value is ignored.

        Raises:
            json.JSONDecodeError: If a value is not valid JSON.
        """

        def flatten(value):
            if isinstance(value, list):
                for element in value:
                    yield from flatten(element)
            else:
                yield value

        decoder = json.JSONDecoder()
        items = []
        pos = 0
        while True:
            # raw_decode stops at the end of the first complete JSON value,
            # so brackets inside arguments cannot truncate the payload.
            value, pos = decoder.raw_decode(payload, pos)
            items.extend(flatten(value))

            separator = cls._NEXT_VALUE.match(payload, pos)
            if separator is None:
                return items
            pos = separator.end()

    @staticmethod
    def _normalize_arguments(name, args) -> dict:
        """Return the call arguments as a dict, logging anything unusable.

        A dict is returned unchanged. A string is decoded when it holds a
        JSON object (arguments in the JSON-encoded string form). Anything
        else results in an empty dict, matching the previous fallback.
        """
        if isinstance(args, dict):
            return args

        if isinstance(args, str):
            try:
                decoded = json.loads(args)
            except json.JSONDecodeError:
                decoded = None
            if isinstance(decoded, dict):
                logger.debug(
                    "Decoded string-encoded arguments for tool '%s'.", name
                )
                return decoded

        logger.error(
            "Arguments for tool '%s' are not a dict. Using empty dict.", name
        )
        return {}
|
|