import ast
import json
import uuid
from typing import Any, Dict, List, Optional, Sequence, Union

import regex as re

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionToolsParam,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer

logger = init_logger(__name__)


@ToolParserManager.register_module("qwen3_coder")
class Qwen3CoderToolParser(ToolParser):
    """
    Tool parser for Qwen3 models using XML-style tool call format:

        <tool_call>
        <function=name>
        <parameter=key>
        value
        </parameter>
        </function>
        </tool_call>

    Port of vllm-original qwen3coder_tool_parser.py to vllm 0.6.3 API.

    Two entry points are provided: ``extract_tool_calls`` parses a complete
    model output, and ``extract_tool_calls_streaming`` turns incremental
    text into OpenAI-style deltas (header, ``{``, parameters, ``}``).
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up tag literals, regexes and streaming state.

        Raises:
            ValueError: if no tokenizer is available on the parser.
            RuntimeError: if the tool call start/end tags are not present
                in the tokenizer vocabulary.
        """
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: List[Dict] = []
        # Base class uses int; we override with string IDs
        self.current_tool_id: Optional[str] = None  # type: ignore[assignment]
        self.streamed_args_for_tool: List[str] = []

        # NOTE(review): tag literals follow the Qwen3-Coder XML tool call
        # format; confirm against the chat template of the served model.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"

        # Give every streaming attribute a defined value before first use.
        self._reset_streaming_state()

        # Only fully closed tool calls.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        # Closed tool calls, or a trailing unterminated one.
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
        # "name>...params..." of each function, closed or unterminated.
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
        # "name>value" of each parameter; a value ends at its closing tag,
        # at the next parameter, at the end of the function, or at EOF.
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|"
            r"(?=</function>)|$)", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

        logger.debug("vLLM Successfully imported tool parser %s !",
                     self.__class__.__name__)

    def _generate_tool_call_id(self) -> str:
        """Return a fresh ``call_`` id with 24 random hex characters."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self) -> None:
        """Reset the per-message streaming state machine.

        ``prev_tool_call_arr`` and ``streamed_args_for_tool`` are
        deliberately left untouched here.
        """
        self.current_tool_index = 0
        self.is_tool_call_started = False
        self.header_sent = False
        self.current_tool_id = None
        self.current_function_name: Optional[str] = None
        self.current_param_name: Optional[str] = None
        self.current_param_value: str = ""
        self.param_count = 0
        self.in_param = False
        self.in_function = False
        self.accumulated_text: str = ""
        self.json_started = False
        self.json_closed = False
        # Raw (unconverted) parameter strings of the tool being streamed.
        self.accumulated_params: Dict[str, Any] = {}
        self.streaming_request: Optional[ChatCompletionRequest] = None

    def _get_arguments_config(
            self, func_name: str,
            tools: Optional[List[ChatCompletionToolsParam]]) -> Dict:
        """Return the parameter schema of tool *func_name*.

        Returns ``parameters["properties"]`` when present, the raw
        ``parameters`` dict when it has no ``properties`` key, and ``{}``
        when the tool is unknown or has no usable schema.
        """
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (
                    hasattr(config, "function")
                    and hasattr(config.function, "name")):
                continue
            if (config.type == "function"
                    and config.function.name == func_name):
                if not hasattr(config.function, "parameters"):
                    return {}
                params = config.function.parameters
                if isinstance(params, dict) and "properties" in params:
                    return params["properties"]
                if isinstance(params, dict):
                    return params
                return {}
        logger.debug("Tool '%s' is not defined in the tools list.", func_name)
        return {}

    def _convert_param_value(self, param_value: str, param_name: str,
                             param_config: Dict, func_name: str) -> Any:
        """Convert a raw parameter string according to the tool schema.

        The result is always JSON-serializable. Whenever a conversion is
        not possible the original string is returned unchanged.

        Args:
            param_value: Raw text between the parameter tags.
            param_name: Name of the parameter.
            param_config: Schema mapping parameter name -> spec dict.
            func_name: Tool name, used for log messages only.
        """
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.debug(
                    "Parsed parameter '%s' is not defined in tool '%s', "
                    "returning string value.", param_name, func_name)
            return param_value

        spec = param_config[param_name]
        if isinstance(spec, dict) and "type" in spec:
            param_type = str(spec["type"]).strip().lower()
        else:
            param_type = "string"

        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value

        if param_type.startswith(
            ("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except (ValueError, TypeError):
                return param_value

        if param_type.startswith(("num", "float")):
            try:
                v = float(param_value)
                # int() raises ValueError for NaN and OverflowError for
                # +/-inf; neither is valid JSON, so keep those as strings.
                return int(v) if v - int(v) == 0 else v
            except (ValueError, TypeError, OverflowError):
                return param_value

        if param_type in ["boolean", "bool", "binary"]:
            lower = param_value.lower()
            if lower not in ["true", "false"]:
                logger.debug(
                    "Parameter '%s' value '%s' is not boolean in tool '%s'.",
                    param_name, param_value, func_name)
            return lower == "true"

        if (param_type in ["object", "array", "arr"]
                or param_type.startswith("dict")
                or param_type.startswith("list")):
            try:
                return json.loads(param_value)
            except (json.JSONDecodeError, TypeError, ValueError):
                pass  # Fall through to the Python-literal attempt below.

        # Best effort for remaining types: accept Python literal syntax
        # (e.g. single-quoted dicts). literal_eval never executes code.
        try:
            evaluated = ast.literal_eval(param_value)
        except (ValueError, SyntaxError, TypeError, RecursionError,
                MemoryError):
            return param_value
        # Literals such as sets, bytes or complex numbers cannot be encoded
        # as JSON; callers json.dumps() the result, so keep the raw string.
        try:
            json.dumps(evaluated)
        except (TypeError, ValueError):
            return param_value
        return evaluated

    def _parse_xml_function_call(
            self, function_call_str: str,
            tools: Optional[List[ChatCompletionToolsParam]]) -> ToolCall:
        """Build a ``ToolCall`` from ``name>...parameters...`` text.

        *function_call_str* is the text following ``<function=`` up to, but
        not including, ``</function>``.

        Raises:
            ValueError: if the function or a parameter header has no ``>``.
        """
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = self._get_arguments_config(function_name, tools)
        parameters = function_call_str[end_index + 1:]

        param_dict: Dict[str, Any] = {}
        for match_text in self.tool_call_parameter_regex.findall(parameters):
            idx = match_text.index(">")
            param_name = match_text[:idx]
            param_value = str(match_text[idx + 1:])
            # Drop the single newline that wraps the value inside the tags.
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]
            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name)

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name,
                arguments=json.dumps(param_dict, ensure_ascii=False)))

    def _get_function_calls(self, model_output: str) -> List[str]:
        """Return the raw ``name>...`` text of every function call found.

        When no tool call wrapper is present the whole output is searched
        for bare function tags.
        """
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]
        if not raw_tool_calls:
            raw_tool_calls = [model_output]

        raw_function_calls: List[tuple] = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(
                self.tool_call_function_regex.findall(tool_call))

        return [
            match[0] if match[0] else match[1] for match in raw_function_calls
        ]

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model response.

        Text before the first tool call is returned as ``content``. On any
        parsing error the whole output is returned as plain content.
        """
        if self.tool_call_prefix not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            function_calls = self._get_function_calls(model_output)
            if not function_calls:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls = [
                self._parse_xml_function_call(fc, request.tools)
                for fc in function_calls
            ]

            self.prev_tool_call_arr.clear()
            for tc in tool_calls:
                self.prev_tool_call_arr.append({
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                })

            # Content ends at the tool call wrapper, or at the bare
            # function tag when the wrapper is missing.
            content_index = model_output.find(self.tool_call_start_token)
            idx = model_output.find(self.tool_call_prefix)
            content_index = content_index if content_index >= 0 else idx
            content = model_output[:content_index]

            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            # Boundary handler: malformed model output must not fail the
            # request; degrade to returning the text as content.
            logger.exception("Error extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    @staticmethod
    def _find_all(text: str, token: str) -> List[int]:
        """Return start offsets of all non-overlapping *token* in *text*."""
        positions: List[int] = []
        if not token:
            return positions
        cursor = text.find(token)
        while cursor != -1:
            positions.append(cursor)
            cursor = text.find(token, cursor + len(token))
        return positions

    def _append_streamed_args(self, fragment: str) -> None:
        """Record *fragment* as streamed for the current tool, if tracked."""
        index = self.current_tool_index
        if index < len(self.streamed_args_for_tool):
            self.streamed_args_for_tool[index] += fragment
        else:
            logger.warning(
                "streamed_args_for_tool out of sync: index=%d len=%d", index,
                len(self.streamed_args_for_tool))

    def _arguments_delta(self, fragment: str) -> DeltaMessage:
        """Wrap an arguments *fragment* for the current tool in a delta."""
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=self.current_tool_index,
                function=DeltaFunctionCall(arguments=fragment),
            )
        ])

    def _current_tool_text(self, current_text: str) -> Optional[str]:
        """Return the text of the tool call being streamed.

        The slice runs from its start tag through its end tag (or to the
        end of the text while unterminated). Returns ``None`` when the
        current tool has not started yet.
        """
        starts = self._find_all(current_text, self.tool_call_start_token)
        if self.current_tool_index >= len(starts):
            return None
        begin = starts[self.current_tool_index]
        end = current_text.find(self.tool_call_end_token, begin)
        if end == -1:
            return current_text[begin:]
        return current_text[begin:end + len(self.tool_call_end_token)]

    def _find_param_value_end(self, value_text: str) -> int:
        """Return where the parameter value ends, or -1 if incomplete.

        A closing parameter tag wins when present anywhere in
        *value_text*; otherwise the value ends at the next parameter, the
        function end, or the tool call end, in that order of preference.
        NOTE(review): this differs from the non-streaming regex, which
        stops at the earliest terminator.
        """
        end = value_text.find(self.parameter_end_token)
        if end != -1:
            return end
        next_param = value_text.find(self.parameter_prefix)
        func_end = value_text.find(self.function_end_token)
        if next_param != -1 and (func_end == -1 or next_param < func_end):
            return next_param
        if func_end != -1:
            return func_end
        return value_text.find(self.tool_call_end_token)

    def _stream_header(self, tool_text: str) -> Optional[DeltaMessage]:
        """Emit the id/name delta once the function name is complete."""
        if self.tool_call_prefix not in tool_text:
            return None
        func_start = (tool_text.find(self.tool_call_prefix) +
                      len(self.tool_call_prefix))
        func_end = tool_text.find(">", func_start)
        if func_end == -1:
            return None  # Function name still being generated.

        self.current_function_name = tool_text[func_start:func_end]
        self.current_tool_id = self._generate_tool_call_id()
        self.header_sent = True
        self.in_function = True
        self.prev_tool_call_arr.append({
            "name": self.current_function_name,
            "arguments": "{}",
        })
        self.streamed_args_for_tool.append("")
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=self.current_tool_index,
                id=self.current_tool_id,
                function=DeltaFunctionCall(name=self.current_function_name,
                                           arguments=""),
                type="function",
            )
        ])

    def _stream_arguments(self, tool_text: str) -> Optional[DeltaMessage]:
        """Emit the next piece of the JSON arguments for the current tool.

        Emits, one kind per call: the opening ``{``, then all newly
        completed parameters, then the closing ``}`` once ``</function>``
        has been seen.
        """
        if not self.json_started:
            self.json_started = True
            self._append_streamed_args("{")
            return self._arguments_delta("{")

        tools = (self.streaming_request.tools
                 if self.streaming_request else None)
        func_name = self.current_function_name or ""
        param_config: Optional[Dict] = None  # Looked up once, on demand.

        # Collect all complete parameters in one pass
        # (speculative-decode safe)
        param_starts = self._find_all(tool_text, self.parameter_prefix)
        json_fragments: List[str] = []
        while not self.in_param and self.param_count < len(param_starts):
            param_start = (param_starts[self.param_count] +
                           len(self.parameter_prefix))
            name_end = tool_text.find(">", param_start)
            if name_end == -1:
                break  # Parameter name still being generated.
            current_param_name = tool_text[param_start:name_end]

            value_text = tool_text[name_end + 1:]
            if value_text.startswith("\n"):
                value_text = value_text[1:]
            param_end_idx = self._find_param_value_end(value_text)
            if param_end_idx == -1:
                break  # Value still being generated.

            param_value = value_text[:param_end_idx]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]
            self.accumulated_params[current_param_name] = param_value

            if param_config is None:
                param_config = self._get_arguments_config(func_name, tools)
            converted = self._convert_param_value(param_value,
                                                  current_param_name,
                                                  param_config, func_name)
            # Encode the key with json.dumps so names containing quotes or
            # backslashes still produce valid JSON.
            key = json.dumps(current_param_name, ensure_ascii=False)
            serialized = json.dumps(converted, ensure_ascii=False)
            sep = "" if self.param_count == 0 else ", "
            json_fragments.append(f"{sep}{key}: {serialized}")
            self.param_count += 1

        if json_fragments:
            combined = "".join(json_fragments)
            self._append_streamed_args(combined)
            return self._arguments_delta(combined)

        # Emit closing brace when </function> is seen (after params are
        # done)
        if not self.json_closed and self.function_end_token in tool_text:
            self.json_closed = True
            func_start = (tool_text.find(self.tool_call_prefix) +
                          len(self.tool_call_prefix))
            func_content_end = tool_text.find(self.function_end_token,
                                              func_start)
            if func_content_end != -1:
                # Re-parse the finished call so prev_tool_call_arr holds
                # the same arguments JSON the non-streaming path produces.
                try:
                    parsed_tool = self._parse_xml_function_call(
                        tool_text[func_start:func_content_end], tools)
                    if self.current_tool_index < len(
                            self.prev_tool_call_arr):
                        self.prev_tool_call_arr[
                            self.current_tool_index]["arguments"] = (
                                parsed_tool.function.arguments)
                except Exception:
                    logger.debug(
                        "Failed to parse tool call during "
                        "streaming: %s",
                        tool_text,
                        exc_info=True)

            self._append_streamed_args("}")
            result = self._arguments_delta("}")
            self.in_function = False
            self.accumulated_params = {}
            return result

        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Translate one streaming step into an OpenAI delta.

        Returns a content delta for plain text, a tool call delta for the
        next header/arguments piece, or ``None`` when nothing can be
        emitted yet.
        """
        # An empty previous_text marks the first chunk of a new message.
        if not previous_text:
            self._reset_streaming_state()
            self.streaming_request = request

        if not delta_text:
            # A token without text (e.g. EOS) after everything is closed
            # gets an empty content delta rather than None.
            if (delta_token_ids
                    and self.tool_call_end_token_id not in delta_token_ids):
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text))
                if complete_calls > 0 and self.prev_tool_call_arr:
                    open_calls = (
                        current_text.count(self.tool_call_start_token) -
                        current_text.count(self.tool_call_end_token))
                    if open_calls == 0:
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    return DeltaMessage(content="")
            return None

        self.accumulated_text = current_text

        # The previous tool is fully emitted: advance once its end tag
        # has arrived.
        if self.json_closed and not self.in_function:
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False
                self.accumulated_params = {}

                tool_starts = current_text.count(self.tool_call_start_token)
                if self.current_tool_index >= tool_starts:
                    self.is_tool_call_started = False
                return None

        if not self.is_tool_call_started:
            if (self.tool_call_start_token_id in delta_token_ids
                    or self.tool_call_start_token in delta_text):
                self.is_tool_call_started = True
                # Flush any content that shares a delta with the start tag.
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[:delta_text.index(
                        self.tool_call_start_token)]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            # Swallow whitespace right after a finished tool call.
            if (current_text.rstrip().endswith(self.tool_call_end_token)
                    and delta_text.strip() == ""):
                return None
            return DeltaMessage(content=delta_text)

        tool_text = self._current_tool_text(current_text)
        if tool_text is None:
            return None  # Waiting for the next tool call to start.

        if not self.header_sent:
            return self._stream_header(tool_text)

        if self.in_function:
            return self._stream_arguments(tool_text)

        return None