# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code modified from deepseekv3_tool_parser.py
from collections.abc import Sequence

import regex as re

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)

logger = init_logger(__name__)


class KimiK2ToolParser(ToolParser):
    """Tool-call parser for Kimi-K2 style tool-call markup.

    Tool calls are wrapped in a section delimited by
    ``<|tool_calls_section_begin|>`` / ``<|tool_calls_section_end|>`` (the
    singular ``<|tool_call_section_begin|>`` / ``<|tool_call_section_end|>``
    spellings are accepted too). Each call inside the section looks like::

        <|tool_call_begin|>ID<|tool_call_argument_begin|>ARGS<|tool_call_end|>

    where ``ID`` is e.g. ``functions.get_weather:0`` or ``get_weather:0``.
    ``ID`` is passed through as the tool call id; the function name is the
    part before ``:`` with any dotted prefix removed.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        # Arguments streamed so far for each tool call, indexed by tool id.
        self.streamed_args_for_tool: list[str] = []

        # Section-level state management to prevent token leakage
        self.in_tool_section: bool = False
        self.token_buffer: str = ""
        # Buffer size: empirical worst-case for longest marker (~30 chars) * 2
        # + safety margin for unicode + partial overlap. Prevents unbounded growth.
        self.buffer_max_size: int = 1024
        self.section_char_count: int = 0  # Track characters processed in tool section
        self.max_section_chars: int = 8192  # Force exit if section exceeds this
        # Overflow is logged at most once per parser instance; the flag is not
        # cleared by reset_streaming_state().
        self._buffer_overflow_logged: bool = False

        # Support both singular and plural variants
        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
        self.tool_calls_start_token_variants: list[str] = [
            "<|tool_calls_section_begin|>",
            "<|tool_call_section_begin|>",  # singular variant
        ]
        self.tool_calls_end_token_variants: list[str] = [
            "<|tool_calls_section_end|>",
            "<|tool_call_section_end|>",  # singular variant
        ]

        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # A complete call: (call id, arguments). The arguments may not contain
        # another <|tool_call_begin|>, so an unterminated call cannot swallow
        # the call that follows it.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*<\|tool_call_end\|>",
            re.DOTALL,
        )

        # A partial call while streaming: id plus the arguments seen so far.
        # NOTE(review): compiled without re.DOTALL, so `.*` stops at the first
        # newline inside the arguments -- confirm arguments are single-line.
        self.stream_tool_call_portion_regex = re.compile(
            r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
        )

        # A partial call while streaming where only the id has arrived.
        self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*")

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )
        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

        # Get token IDs for all variants
        self.tool_calls_start_token_ids: list[int] = [
            tid
            for variant in self.tool_calls_start_token_variants
            if (tid := self.vocab.get(variant)) is not None
        ]
        self.tool_calls_end_token_ids: list[int] = [
            tid
            for variant in self.tool_calls_end_token_variants
            if (tid := self.vocab.get(variant)) is not None
        ]

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if (
            self.tool_calls_start_token_id is None
            or self.tool_calls_end_token_id is None
        ):
            raise RuntimeError(
                "Kimi-K2 Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

    @staticmethod
    def _function_name_from_id(tool_call_id: str) -> str:
        """Derive the function name from a call id.

        ``functions.get_weather:0`` and ``get_weather:0`` both give
        ``get_weather``.
        """
        return tool_call_id.split(":")[0].split(".")[-1]

    def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
        """
        Check for section begin/end markers in text and strip them.

        Returns:
            (cleaned_text, found_section_begin, found_section_end)
        """
        found_begin = False
        found_end = False
        cleaned = text

        # Check for section begin markers (any variant)
        for variant in self.tool_calls_start_token_variants:
            if variant in cleaned:
                cleaned = cleaned.replace(variant, "")
                found_begin = True

        # Check for section end markers (any variant)
        for variant in self.tool_calls_end_token_variants:
            if variant in cleaned:
                cleaned = cleaned.replace(variant, "")
                found_end = True

        return cleaned, found_begin, found_end

    def _find_section_start(self, text: str) -> int:
        """Return the index of the earliest section-begin marker (any variant)
        in ``text``, or -1 when none is present."""
        positions = [
            pos
            for variant in self.tool_calls_start_token_variants
            if (pos := text.find(variant)) != -1
        ]
        return min(positions, default=-1)

    def _text_after_last_section_end(self, text: str) -> str:
        """Return the text after the last section-end marker (any variant) in
        ``text``, with any other section markers stripped.

        Returns "" when ``text`` contains no section-end marker.
        """
        cut = -1
        for variant in self.tool_calls_end_token_variants:
            pos = text.rfind(variant)
            if pos != -1:
                cut = max(cut, pos + len(variant))
        if cut == -1:
            return ""
        tail, _, _ = self._check_and_strip_markers(text[cut:])
        return tail

    def _reset_section_state(self) -> None:
        """Reset state when exiting tool section."""
        self.in_tool_section = False
        self.token_buffer = ""
        self.section_char_count = 0

    def reset_streaming_state(self) -> None:
        """
        Reset all streaming state. Call this between requests to prevent
        state leakage when parser instance is reused.
        """
        # Reset section state
        self._reset_section_state()

        # Reset parent class state
        self.current_tool_name_sent = False
        self.prev_tool_call_arr = []
        self.current_tool_id = -1
        self.streamed_args_for_tool = []

        logger.debug("Streaming state reset")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete tool call from a finished model output.

        Any section-begin variant (plural or singular) starts the tool
        section. Text before the earliest such marker is returned as content.
        If no marker is present, the whole output is returned as content with
        ``tools_called=False``. A parsing error is logged and also yields the
        raw output as content.
        """
        section_start = self._find_section_start(model_output)
        # sanity check; avoid unnecessary processing
        if section_start == -1:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # findall yields one (call id, arguments) tuple per complete call,
            # e.g. ("functions.get_weather:0", '{"city": "Paris"}').
            function_call_tuples = self.tool_call_regex.findall(model_output)
            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = [
                ToolCall(
                    id=function_id,
                    type="function",
                    function=FunctionCall(
                        name=self._function_name_from_id(function_id),
                        arguments=function_args,
                    ),
                )
                for function_id, function_args in function_call_tuples
            ]

            content = model_output[:section_start]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Convert one streamed chunk into a content or tool-call delta.

        Section markers are detected in ``token_buffer``, so a marker split
        across chunks is still recognised and stripped. Before a tool section
        starts, the chunk passes through as content. Inside a section it is
        parsed into tool-call name/argument deltas by
        ``_parse_tool_call_delta``.

        Returns:
            A ``DeltaMessage`` to emit (possibly with empty content), or
            ``None`` when this chunk should produce no output.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # Set when the section-end marker arrives in the same chunk as a
        # tool_call_end token. The section must then stay open until that
        # final tool update has been parsed; the exit happens in `finally`.
        deferred_section_exit = False

        # Add delta to buffer for split marker detection
        self.token_buffer += delta_text

        # Enforce buffer size limit to prevent memory issues
        if len(self.token_buffer) > self.buffer_max_size:
            if not self._buffer_overflow_logged:
                logger.warning(
                    "Token buffer exceeded max size (%d chars), flushing excess. "
                    "This may indicate very long markers or unusual tokenization.",
                    self.buffer_max_size,
                )
                self._buffer_overflow_logged = True
            # Keep only the most recent content that might contain partial markers
            self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]

        # Keep the unstripped buffer: a section exit needs the marker position.
        raw_buffer = self.token_buffer

        # Check buffer for section markers (handles split tokens)
        buffered_text, found_section_begin, found_section_end = (
            self._check_and_strip_markers(raw_buffer)
        )

        # Track section state transitions
        if found_section_begin and not self.in_tool_section:
            logger.debug("Entering tool section")
            self.in_tool_section = True
            self.token_buffer = buffered_text  # Use cleaned buffer
            self.section_char_count = 0  # Reset counter for new section

        if found_section_end and self.in_tool_section:
            logger.debug("Detected section end marker")
            # CRITICAL: Don't exit early if tool_call_end is in this chunk.
            # Tool parser must emit final arguments/close first to avoid dropping
            # the final tool update and leaking tokens into reasoning channel.
            if self.tool_call_end_token_id in delta_token_ids:
                deferred_section_exit = True
                logger.debug("Deferring section exit: tool_call_end in same chunk")
                self.token_buffer = buffered_text
            else:
                # No tool call ending, safe to exit immediately
                logger.debug("Exiting tool section")
                # The buffer still holds everything seen so far: pre-section
                # text (already emitted as content) and the tool-call markup
                # (already parsed). Only text after the end marker is new.
                remaining = self._text_after_last_section_end(raw_buffer)
                self._reset_section_state()
                # Always return a DeltaMessage here (empty if nothing remains).
                return DeltaMessage(content=remaining if remaining.strip() else "")
        else:
            self.token_buffer = buffered_text

        # Check if any variant of section start token is in current_token_ids
        has_section_token = any(
            tid in current_token_ids for tid in self.tool_calls_start_token_ids
        )

        # Early return: if no section token detected yet, return as content
        if not has_section_token and not self.in_tool_section:
            logger.debug("No tool call tokens found!")
            # Don't clear buffer - it needs to accumulate partial markers across
            # deltas. Its growth is bounded by the buffer_max_size check above.
            return DeltaMessage(content=delta_text)

        # Strip section markers from delta_text before tool-call parsing.
        # Section markers and tool call markers are distinct, so nothing is
        # stripped twice.
        delta_text, _, _ = self._check_and_strip_markers(delta_text)

        # Error recovery: If in tool section for too long, force exit
        if self.in_tool_section:
            self.section_char_count += len(delta_text)
            if self.section_char_count > self.max_section_chars:
                logger.warning(
                    "Tool section exceeded max length (%d chars), forcing exit. "
                    "This may indicate malformed model output.",
                    self.max_section_chars,
                )
                self._reset_section_state()
                # Any deferred exit is subsumed by this forced exit.
                # Return remaining content (or an empty delta if there is none).
                return DeltaMessage(content=delta_text if delta_text.strip() else "")

        try:
            return self._parse_tool_call_delta(
                current_text,
                delta_text,
                previous_token_ids,
                current_token_ids,
                delta_token_ids,
            )
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None  # do not stream a delta. skip this token ID.
        finally:
            # Complete a deferred section exit on every way out of the parser
            # (early returns and errors included). Otherwise the section state
            # outlives the end marker that this chunk already consumed.
            if deferred_section_exit and self.in_tool_section:
                logger.debug("Completing deferred section exit")
                self._reset_section_state()

    def _tool_call_arguments_delta(self, arguments: str) -> DeltaMessage:
        """Build a delta that appends ``arguments`` to the current tool call."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=arguments).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _parse_tool_call_delta(
        self,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Turn one chunk (section markers already stripped) into a delta.

        Progress is tracked by counting tool_call_begin / tool_call_end token
        ids. The method updates ``current_tool_id``,
        ``current_tool_name_sent``, ``streamed_args_for_tool`` and
        ``prev_tool_call_arr``. Section enter/exit is handled by the caller.
        """
        # figure out where we are in the parsing by counting tool call
        # start & end tags
        prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id)
        prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
        cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id)
        cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
        tool_call_portion = None
        text_portion = None

        # case: if we're generating text, OR rounding out a tool call
        if (
            cur_tool_start_count == cur_tool_end_count
            and prev_tool_end_count == cur_tool_end_count
            and self.tool_call_end_token not in delta_text
        ):
            # Suppress content if in tool section but no tool calls started
            if self.in_tool_section and cur_tool_start_count == 0:
                logger.debug(
                    "In tool section but no tool calls started yet. Suppressing: %s",
                    delta_text,
                )
                # Return empty delta to maintain iterator contract
                return DeltaMessage(content="")
            logger.debug("Generating text content! skipping tool parsing.")
            return DeltaMessage(content=delta_text)

        if self.tool_call_end_token in delta_text:
            logger.debug("tool_call_end_token in delta_text")
            # current_text already ends with this chunk, so it alone holds the
            # full text of the call being closed.
            tool_call_portion = (
                current_text.split(self.tool_call_start_token)[-1]
                .split(self.tool_call_end_token)[0]
                .rstrip()
            )
            # Split the chunk around the end token before truncating it, so
            # text_portion really is the text after the end token.
            segments = delta_text.split(self.tool_call_end_token)
            text_portion = segments[-1].lstrip()
            delta_text = segments[0].rstrip()

        # case -- we're starting a new tool call
        if (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count > prev_tool_start_count
        ):
            if len(delta_token_ids) > 1:
                tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
            else:
                tool_call_portion = None
            text_portion = None

            # set cursors and state appropriately
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("Starting on a new tool %s", self.current_tool_id)

        # case -- we're updating an existing tool call
        elif (
            cur_tool_start_count > cur_tool_end_count
            and cur_tool_start_count == prev_tool_start_count
        ):
            # get the portion of the text that's the tool call
            tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
            text_portion = None

        # case -- the current tool call is being closed.
        elif (
            cur_tool_start_count == cur_tool_end_count
            and cur_tool_end_count >= prev_tool_end_count
        ):
            if not self.prev_tool_call_arr or self.current_tool_id >= len(
                self.prev_tool_call_arr
            ):
                logger.debug("attempting to close tool call, but no tool call")
                return None
            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                "arguments"
            )
            if prev_arguments:
                # Flush the unstreamed tail of the arguments only once the
                # closing '"}' has arrived in this chunk.
                if '"}' not in delta_text:
                    return None
                end_loc = delta_text.rindex('"}')
                diff = delta_text[:end_loc] + '"}'
                logger.debug(
                    "Finishing tool and found diff that had not been streamed yet: %s",
                    diff,
                )
                self.streamed_args_for_tool[self.current_tool_id] += diff
                return self._tool_call_arguments_delta(diff)
            # No arguments recorded yet: fall through and parse the portion.

        # case -- otherwise we're just generating text
        else:
            # Check if we're in tool section - if so, suppress
            if self.in_tool_section:
                logger.debug("In tool section, suppressing text generation")
                return DeltaMessage(content="")
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            return DeltaMessage(tool_calls=[], content=text)

        current_tool_call: dict = {}
        if tool_call_portion:
            portion_match = self.stream_tool_call_portion_regex.match(tool_call_portion)
            if portion_match:
                tool_id, tool_args = portion_match.groups()
                current_tool_call = {
                    "id": tool_id,
                    "name": self._function_name_from_id(tool_id),
                    "arguments": tool_args,
                }
            else:
                name_match = self.stream_tool_call_name_regex.match(tool_call_portion)
                if not name_match:
                    logger.debug("Not enough token")
                    return None
                (tool_id,) = name_match.groups()
                current_tool_call = {
                    "id": tool_id,
                    "name": self._function_name_from_id(tool_id),
                    "arguments": "",
                }

        # case - we haven't sent the tool name yet. If it's available, send
        # it. otherwise, wait until it's available.
        if not self.current_tool_name_sent:
            function_name: str | None = current_tool_call.get("name")
            if not function_name:
                return None
            self.current_tool_name_sent = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=current_tool_call.get("id"),
                        function=DeltaFunctionCall(name=function_name).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )

        # if the tool call portion is None, send the delta as text if there is
        # any, otherwise None to skip the chunk
        if tool_call_portion is None:
            return DeltaMessage(content=delta_text) if text_portion is not None else None

        logger.debug("Trying to parse current tool call with ID %s", self.current_tool_id)

        # if we're starting a new tool call, push an empty object in as
        # a placeholder for the arguments
        if len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})

        # main logic for tool parsing here - compare the previously parsed
        # arguments to the current ones and emit only the new suffix
        prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
        cur_arguments = current_tool_call.get("arguments")

        logger.debug("diffing old arguments: %s", prev_arguments)
        logger.debug("against new ones: %s", cur_arguments)

        delta: DeltaMessage | None = None
        if not cur_arguments and not prev_arguments:
            # no arguments have been created yet. skip sending a delta.
            logger.debug("Skipping text %s - no arguments", delta_text)
        elif not cur_arguments:
            # prev arguments are defined, but none are now.
            # probably impossible, but not a fatal error - just keep going
            logger.error(
                "should be impossible to have arguments reset "
                "mid-call. skipping streaming anything."
            )
        elif not prev_arguments:
            # first info about arguments is available
            delta = self._tool_call_arguments_delta(cur_arguments)
            self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
        elif len(cur_arguments) > len(prev_arguments) and cur_arguments.startswith(
            prev_arguments
        ):
            # arguments grew: stream only the new suffix
            delta_arguments = cur_arguments[len(prev_arguments) :]
            logger.debug("got diff %s", delta_arguments)
            delta = self._tool_call_arguments_delta(delta_arguments)
            self.streamed_args_for_tool[self.current_tool_id] = cur_arguments

        # save the state for the current tool into the "prev" list for use
        # in diffing on the next iteration
        if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
            self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
        else:
            self.prev_tool_call_arr.append(current_tool_call)

        return delta