# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from collections.abc import Sequence import regex as re from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall, ) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) class FunctionGemmaToolParser(ToolParser): """ Tool parser for Google's FunctionGemma model (google/functiongemma-270m-it). Handles the FunctionGemma function call format: call:func_name{param:value} """ def __init__(self, tokenizer: TokenizerLike): super().__init__(tokenizer) # Streaming state self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.streamed_args_for_tool: list[str] = [] # FunctionGemma tokens self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" # Regex patterns self.tool_call_regex = re.compile( r"call:(\w+)\{(.*?)\}" r"|call:(\w+)\{(.*)", re.DOTALL, ) self.arg_regex = re.compile( r"(\w+):(.*?)", re.DOTALL, ) if self.model_tokenizer: self.tool_call_start_token_ids = self.model_tokenizer.encode( self.tool_call_start_token, add_special_tokens=False ) self.tool_call_end_token_ids = self.model_tokenizer.encode( self.tool_call_end_token, add_special_tokens=False ) else: self.tool_call_start_token_ids = [] self.tool_call_end_token_ids = [] self.buffered_delta_text = "" def _parse_arguments(self, args_str: str) -> dict: """Parse FunctionGemma argument string into a dictionary.""" arguments = {} if not args_str: return arguments matches = self.arg_regex.findall(args_str) for key, value in matches: try: parsed_value = json.loads(value) arguments[key] = parsed_value except json.JSONDecodeError: arguments[key] = value return arguments def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: request = super().adjust_request(request) if request.tools and request.tool_choice != "none": request.skip_special_tokens = False return request def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: if self.tool_call_start_token not in model_output: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) try: matches = self.tool_call_regex.findall(model_output) if not matches: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) tool_calls: list[ToolCall] = [] for match in matches: func_name = match[0] if match[0] else match[2] args_str = match[1] if match[1] else match[3] if not func_name: continue arguments = self._parse_arguments(args_str) tool_calls.append( ToolCall( type="function", function=FunctionCall( name=func_name, arguments=json.dumps(arguments, ensure_ascii=False), ), ) ) if tool_calls: content_end = model_output.find(self.tool_call_start_token) content = ( model_output[:content_end].strip() if content_end > 0 else None ) return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, content=content if content else None, ) return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) except Exception: logger.exception("Error extracting tool calls from FunctionGemma response") return ExtractedToolCallInformation( tools_called=False, tool_calls=[], content=model_output ) def _buffer_delta_text(self, delta_text: str) -> str: """Buffer incoming delta text to handle multi-token special sequences.""" potential_start = "" potential_end = "" combined = self.buffered_delta_text + delta_text if combined.endswith(potential_start) or combined.endswith(potential_end): self.buffered_delta_text = "" return combined for tag in [potential_start, potential_end]: for i in range(1, len(tag)): if combined.endswith(tag[:i]): self.buffered_delta_text = combined[-(i):] return combined[:-i] self.buffered_delta_text = "" return combined def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: delta_text = self._buffer_delta_text(delta_text) current_text = previous_text + delta_text if self.tool_call_start_token not in current_text: if delta_text: return DeltaMessage(content=delta_text) return None try: start_count = current_text.count(self.tool_call_start_token) end_count = current_text.count(self.tool_call_end_token) prev_start_count = previous_text.count(self.tool_call_start_token) prev_end_count = previous_text.count(self.tool_call_end_token) if self.tool_call_start_token not in current_text: return DeltaMessage(content=delta_text) # Starting a new function call if start_count > prev_start_count and start_count > end_count: self.current_tool_id += 1 self.current_tool_name_sent = False self.streamed_args_for_tool.append("") self.prev_tool_call_arr.append({}) logger.debug("Starting new tool call %d", self.current_tool_id) return None # In the middle of a function call if start_count > end_count: last_start = current_text.rfind(self.tool_call_start_token) partial_call = current_text[ last_start + len(self.tool_call_start_token) : ] if partial_call.startswith("call:"): func_part = partial_call[5:] if "{" in func_part: func_name = func_part.split("{")[0] args_part = ( func_part.split("{", 1)[1] if "{" in func_part else "" ) if not self.current_tool_name_sent and func_name: self.current_tool_name_sent = True self.prev_tool_call_arr[self.current_tool_id] = { "name": func_name, "arguments": {}, } return DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, type="function", id=make_tool_call_id(), function=DeltaFunctionCall( name=func_name ).model_dump(exclude_none=True), ) ] ) if self.current_tool_name_sent and args_part: current_args = self._parse_arguments(args_part) if current_args: current_args_json = json.dumps( current_args, ensure_ascii=False ) prev_streamed = self.streamed_args_for_tool[ self.current_tool_id ] if len(current_args_json) > len(prev_streamed): diff = current_args_json[len(prev_streamed) :] self.streamed_args_for_tool[ self.current_tool_id ] = current_args_json self.prev_tool_call_arr[self.current_tool_id][ "arguments" ] = current_args return DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff ).model_dump(exclude_none=True), ) ] ) return None # Function call just ended if end_count > prev_end_count: if self.current_tool_id >= 0 and self.current_tool_id < len( self.prev_tool_call_arr ): all_calls = self.tool_call_regex.findall(current_text) args = {} if self.current_tool_id < len(all_calls): match = all_calls[self.current_tool_id] if match[0]: args_str = match[1] args = self._parse_arguments(args_str) self.prev_tool_call_arr[self.current_tool_id][ "arguments" ] = args if args: args_json = json.dumps(args, ensure_ascii=False) prev_streamed = self.streamed_args_for_tool[ self.current_tool_id ] if len(args_json) > len(prev_streamed): diff = args_json[len(prev_streamed) :] self.streamed_args_for_tool[self.current_tool_id] = ( args_json ) return DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall( arguments=diff ).model_dump(exclude_none=True), ) ] ) return None if delta_text: return DeltaMessage(content=delta_text) return None except Exception: logger.exception("Error in streaming tool call extraction") return None