File size: 5,080 Bytes
23dd25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
import re
import uuid

from vllm.entrypoints.openai.protocol import (
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser,
    ToolParserManager,
)
from vllm.logger import init_logger

# Module-level logger created with vLLM's `init_logger` helper; used by the
# parser below for its step-by-step debug output.
logger = init_logger(__name__)


@ToolParserManager.register_module("mistral_v3_debug")
class MistralV3DebugToolParser(ToolParser):
    """
    Custom parser for Mistral v3 with detailed logging.
    Ensures OpenAI-compatible tool calls while debugging missing arguments.

    Recognised model output::

        <optional text>[TOOL_CALLS][{"name": "...", "arguments": {...}}, ...]

    The JSON after ``[TOOL_CALLS]`` may be a JSON array or a bare,
    comma-separated sequence of objects. Whitespace after the token and any
    text after the JSON value are tolerated.
    """

    # Special token that introduces the tool-call payload.
    _BOT_TOKEN = "[TOOL_CALLS]"
    # Any (possibly empty) run of whitespace; used to step between JSON values,
    # because JSONDecoder.raw_decode does not skip leading whitespace itself.
    _WHITESPACE = re.compile(r"\s*")

    def extract_tool_calls(
        self, model_output: str, request
    ) -> ExtractedToolCallInformation:
        """
        Extracts tool calls from model output using Mistral's special tokens.

        Args:
            model_output: Full decoded text generated by the model.
            request: The originating request; only logged here.

        Returns:
            ``tools_called=True`` with one ToolCall per usable item, plus any
            text that preceded ``[TOOL_CALLS]`` as content (``None`` if that
            text is blank). Otherwise ``tools_called=False`` with the untouched
            ``model_output`` as content. This happens when the token is
            absent, the JSON cannot be decoded, no item is usable, or an
            unexpected error occurs.
        """
        logger.info("🔍 Extracting tool calls from model output... | %r", request)
        logger.info("Raw model output: %s", model_output)

        try:
            token_pos = model_output.find(self._BOT_TOKEN)
            if token_pos == -1:
                logger.warning(
                    "⚠️ No valid [TOOL_CALLS] block found. Treating as normal content."
                )
                return self._plain_content(model_output)

            payload = model_output[token_pos + len(self._BOT_TOKEN):]
            logger.debug("📥 Extracted JSON snippet: %s", payload)
            tool_call_data = self._decode_tool_call_payload(payload)

            tool_calls = []
            for i, tool_item in enumerate(tool_call_data):
                logger.debug("🛠️ Processing item %d: %s", i, tool_item)
                tool_call = self._build_tool_call(i, tool_item)
                if tool_call is not None:
                    tool_calls.append(tool_call)

            # Never report tools_called=True with an empty list: the client
            # would receive neither tool calls nor content.
            if not tool_calls:
                logger.error("🚨 No usable tool calls found. Returning as content.")
                return self._plain_content(model_output)

            prefix = model_output[:token_pos]
            logger.info("✅ Successfully extracted %d tool call(s).", len(tool_calls))
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                # Keep any prose the model wrote before the tool calls.
                content=prefix if prefix.strip() else None,
            )

        except json.JSONDecodeError as e:
            logger.error("❌ Failed to parse tool calls JSON: %s", e)
            return self._plain_content(model_output)

        except Exception:
            # Boundary handler: a parser failure must not fail the request,
            # so log the traceback and fall back to plain content.
            logger.exception("🔥 Unexpected error while parsing tool calls.")
            return self._plain_content(model_output)

    @staticmethod
    def _plain_content(model_output: str) -> ExtractedToolCallInformation:
        """Return a "no tool calls" result that passes ``model_output`` through."""
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    @classmethod
    def _decode_tool_call_payload(cls, payload: str) -> list:
        """
        Decode the JSON that follows ``[TOOL_CALLS]`` into a list.

        If the payload (after leading whitespace) starts with ``[``, it is
        decoded as one JSON array. The JSON decoder handles brackets nested
        inside arguments, which a regex cannot. Otherwise the payload is read
        as comma-separated JSON values; a trailing comma is tolerated. Text
        after the last decoded value is ignored.

        Raises:
            json.JSONDecodeError: if no JSON value can be decoded.
        """
        decoder = json.JSONDecoder()
        pos = cls._WHITESPACE.match(payload).end()
        if payload.startswith("[", pos):
            data, _ = decoder.raw_decode(payload, pos)
            return data

        logger.debug("🔧 No leading '['; reading comma-separated JSON values")
        items = []
        while True:
            item, pos = decoder.raw_decode(payload, pos)
            items.append(item)
            pos = cls._WHITESPACE.match(payload, pos).end()
            if not payload.startswith(",", pos):
                return items
            pos = cls._WHITESPACE.match(payload, pos + 1).end()
            if pos == len(payload):  # trailing comma with nothing after it
                return items

    @classmethod
    def _build_tool_call(cls, index: int, tool_item) -> "ToolCall | None":
        """
        Convert one decoded item into an OpenAI-style ToolCall.

        Returns None (after logging) when the item is not a JSON object or
        lacks a non-empty string ``name``.
        """
        if not isinstance(tool_item, dict):
            logger.error("❌ Item %d is not a JSON object. Skipping.", index)
            return None

        name = tool_item.get("name")
        if not isinstance(name, str) or not name:
            # A placeholder name would reach the client as a call to a
            # function that does not exist, so drop the item instead.
            logger.error("❌ Item %d has no valid 'name'. Skipping.", index)
            return None

        args = cls._coerce_arguments(name, tool_item.get("arguments", {}))
        # OpenAI-compatible function calls carry arguments as a JSON string.
        arguments_json = json.dumps(args, ensure_ascii=False)
        logger.debug("✅ Parsed arguments for '%s': %s", name, arguments_json)

        return ToolCall(
            type="function",
            id=cls._new_tool_call_id(),
            function=FunctionCall(name=name, arguments=arguments_json),
        )

    @staticmethod
    def _coerce_arguments(name: str, args) -> dict:
        """
        Return the tool arguments as a dict.

        A string is decoded as JSON, in case the model emitted the arguments
        already JSON-encoded; a blank string becomes ``{}``. Anything that is
        still not a dict (including undecodable strings) is replaced by ``{}``
        and logged.
        """
        if isinstance(args, str):
            try:
                args = json.loads(args) if args.strip() else {}
            except json.JSONDecodeError:
                logger.error(
                    "❌ Arguments for tool '%s' are an undecodable string: %r",
                    name,
                    args,
                )
                return {}
        if not isinstance(args, dict):
            logger.error(
                "❌ Arguments for tool '%s' are not a dict. Using empty dict.", name
            )
            return {}
        return args

    @staticmethod
    def _new_tool_call_id() -> str:
        """
        Return a fresh random 9-character alphanumeric tool-call id.

        Positional ids such as ``call_0`` would repeat on every turn and
        collide in a multi-turn conversation. Nine alphanumeric characters
        also matches the id format that Mistral's own tokenizer tooling is
        understood to validate (NOTE(review): confirm against the deployed
        mistral_common version).
        """
        return uuid.uuid4().hex[:9]