# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike

if TYPE_CHECKING:
    from vllm.entrypoints.openai.protocol import (
        ChatCompletionRequest,
        ResponsesRequest,
    )
else:
    ChatCompletionRequest = Any
    ResponsesRequest = Any


class BaseThinkingReasoningParser(ReasoningParser):
    """
    Base class for reasoning parsers that use thinking tokens.

    This class provides common functionality for parsers that use a start
    token and an end token to delimit reasoning content
    (e.g., ``<think>...</think>``).

    Subclasses must implement the start and end tokens via the abstract
    ``start_token`` and ``end_token`` properties.
    """

    @property
    @abstractmethod
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        raise NotImplementedError

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        """
        Resolve the start/end token strings to their vocabulary IDs.

        Raises:
            ValueError: if ``self.model_tokenizer`` is falsy after the parent
                constructor ran, or if a subclass returns an empty
                ``start_token`` / ``end_token``.
            RuntimeError: if either token has no entry in ``self.vocab``.
        """
        super().__init__(tokenizer, *args, **kwargs)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

        if not self.start_token or not self.end_token:
            raise ValueError("start_token and end_token must be defined in subclasses")

        # NOTE(review): `self.vocab` is not defined in this class; presumably
        # the parent exposes it as a token-string -> token-id mapping.
        self.start_token_id = self.vocab.get(self.start_token)
        self.end_token_id = self.vocab.get(self.end_token)
        if self.start_token_id is None or self.end_token_id is None:
            raise RuntimeError(
                f"{self.__class__.__name__} reasoning parser could not locate "
                "think start/end tokens in the tokenizer!"
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """
        Return True if the most recent thinking marker in ``input_ids`` is
        the end token.

        The IDs are scanned from the back: hitting the start token first
        means a reasoning section is still open (False); hitting the end
        token first means it has been closed (True). With neither marker
        present the result is False.
        """
        start_token_id = self.start_token_id
        end_token_id = self.end_token_id

        for token_id in reversed(input_ids):
            if token_id == start_token_id:
                return False
            if token_id == end_token_id:
                return True
        return False

    def is_reasoning_end_streaming(
        self, input_ids: list[int], delta_ids: list[int]
    ) -> bool:
        """Return True if the end token ID appears in ``delta_ids``."""
        end_token_id = self.end_token_id
        return end_token_id in delta_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the token IDs that follow the first end token.

        Returns an empty list when the end token is absent or is the very
        last ID (nothing follows it).
        """
        # A single index() scan replaces the former `in input_ids[:-1]` check
        # (which copied the list) plus a second index() scan. When the end
        # token is the last element the slice below is naturally empty.
        try:
            end_index = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        return input_ids[end_index + 1 :]

    def _reasoning_start_index(self, delta_text: str) -> int:
        """Return the offset just past the start token text, or 0 if absent."""
        start_index = delta_text.find(self.start_token)
        if start_index < 0:
            # The token ID was seen but its text is not in the decoded delta;
            # treat the whole delta as starting inside the reasoning section.
            return 0
        return start_index + len(self.start_token)

    def _split_delta_at_end_token(
        self, delta_text: str, reasoning_start: int
    ) -> DeltaMessage:
        """Split ``delta_text`` into reasoning / content around the end token."""
        end_index = delta_text.find(self.end_token)
        if end_index < 0:
            # str.find returned -1: slicing with it would silently drop the
            # last character and split at an arbitrary offset. Without a
            # textual marker the split point is unknown, so keep the whole
            # remainder as reasoning.
            return DeltaMessage(reasoning=delta_text[reasoning_start:], content=None)

        reasoning = delta_text[reasoning_start:end_index]
        content = delta_text[end_index + len(self.end_token) :]
        return DeltaMessage(reasoning=reasoning, content=content if content else None)

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """
        Extract reasoning content from a delta message.

        Handles streaming output where previous + delta = current.
        Branches are selected from token IDs; the text split itself is done
        on ``delta_text``.

        Returns:
            None when the delta is exactly one start or end token; otherwise
            a DeltaMessage carrying ``reasoning``, ``content`` or both.
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
            self.start_token_id,
            self.end_token_id,
        ):
            return None

        if self.start_token_id in previous_token_ids:
            if self.end_token_id in delta_token_ids:
                # Reasoning was already open and closes inside this delta.
                return self._split_delta_at_end_token(delta_text, 0)
            if self.end_token_id in previous_token_ids:
                # Reasoning already closed earlier: plain content.
                return DeltaMessage(content=delta_text)
            # Reasoning still open.
            return DeltaMessage(reasoning=delta_text)

        if self.start_token_id in delta_token_ids:
            reasoning_start = self._reasoning_start_index(delta_text)
            if self.end_token_id in delta_token_ids:
                # Reasoning opens and closes inside this delta.
                return self._split_delta_at_end_token(delta_text, reasoning_start)
            # Reasoning opens in this delta and continues. Strip the start
            # marker (and anything before it), as the open-and-close branch
            # does, so the literal token does not leak into the reasoning.
            return DeltaMessage(reasoning=delta_text[reasoning_start:])

        # No start token seen so far: this base class treats it as content.
        return DeltaMessage(content=delta_text)

    def extract_reasoning(
        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning content from the model output.

        This is the base implementation that works for most models.
        Subclasses can override this method for specific behavior.

        Returns:
            A ``(reasoning, content)`` pair. If the end token is missing the
            whole (start-stripped) output is returned as reasoning with
            ``None`` content; empty content after the end token is also
            reported as ``None``.
        """
        # Drop the start token and everything before it, if it is present.
        _, start_sep, after_start = model_output.partition(self.start_token)
        if start_sep:
            model_output = after_start

        # For models that may not generate the start token, assume the
        # reasoning content is always at the start.
        reasoning, end_sep, content = model_output.partition(self.end_token)
        if not end_sep:
            return model_output, None

        # If generation stops right after end-of-think, return null content
        return reasoning, content or None