Files
Domyn-Small-v1.0/reasoning_parser_plugin.py

298 lines
12 KiB
Python
Raw Permalink Normal View History

"""Reasoning parser plugin for Domyn-Small ``<think>...</think>`` outputs.
Loaded into vLLM with ``--reasoning-parser-plugin <path>`` and selected via
``--reasoning-parser think_block``. The parser splits each model output on
the literal ``</think>`` marker: everything before it is reasoning,
everything after is final content.
See :class:`ThinkBlockReasoningParser` for the streaming state machine and
how per-request thinking-on/off is discovered.
"""
from __future__ import annotations
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from vllm.reasoning import ReasoningParser, ReasoningParserManager
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
# Marker strings used by the Domyn-Small chat template. The template places
# the opening `<think>` in the prompt, so generated text is not normally
# expected to start with it; the closing `</think>` is the marker that has to
# be found at runtime to separate reasoning from the final answer.
START, END = "<think>", "</think>"
def _max_suffix_prefix(s: str, marker: str) -> str:
"""Longest non-empty suffix of ``s`` that is also a prefix of ``marker``.
Used to decide how many trailing bytes of the streaming buffer must be
held back if those bytes could still grow into ``marker`` on the next
delta, releasing them now would fragment the marker across deltas (e.g.
emitting ``</thi`` and then ``nk>``).
"""
for i in range(min(len(marker) - 1, len(s)), 0, -1):
if s.endswith(marker[:i]):
return s[-i:]
return ""
@ReasoningParserManager.register_module("think_block")
class ThinkBlockReasoningParser(ReasoningParser):
    """Splits model output on the first literal ``</think>`` marker.

    **Streaming.** Olmo3-style buffered state machine: incoming text is
    accumulated in :attr:`_buffer` and only released when the marker is
    either confirmed (split point reached) or ruled out (the buffer tail
    can no longer be a prefix of ``</think>``). This guarantees the marker
    is never fragmented across deltas. Only the *first* ``</think>`` is a
    split point. Once it has been seen, later text is passed through
    unbuffered as content, including any further literal ``</think>``. This
    matches :meth:`extract_reasoning`, which uses ``str.partition``.

    **Per-request lane.** The initial lane (``"reasoning"`` vs
    ``"content"``) is set from the request itself. It is ``"reasoning"`` if
    ``chat_template_kwargs.enable_thinking`` (or ``.thinking``) is truthy,
    or if any system message contains the literal ``"thinking on"``
    directive, mirroring the chat template's own detection.

    **Request discovery.** vLLM instantiates the parser per request from
    inside ``create_chat_completion(self, request, ...)``, but does not
    pass the request to the constructor. We recover it by walking the call
    stack at ``__init__`` time, inspecting only each frame's *function
    arguments* (so we don't accidentally match request-shaped objects in
    module globals or unrelated locals). If no request is found we fall
    back to ``thinking=off``, which keeps tool-call streaming working out
    of the box.
    """

    # Number of trailing token ids decoded by `is_reasoning_end_streaming`.
    # BPE may split `</think>` across several tokens, possibly straddling the
    # previous-vs-delta boundary, so the delta alone is not enough.
    _END_SCAN_WINDOW: int = 64

    def __init__(self, tokenizer, *args, **kwargs) -> None:
        # Base ReasoningParser only accepts `tokenizer`; swallow any extras so
        # the registration signature stays compatible across vLLM versions.
        super().__init__(tokenizer)
        # Streamed text not yet released to the client: at most a trailing
        # partial `</think>` waiting for the next delta.
        self._buffer: str = ""
        # Current lane for streaming output: "reasoning" while inside
        # <think>...</think>, "content" otherwise. Locked to "content" once
        # `</think>` is observed.
        self._state: str = "content"
        # Whether the request asked for thinking. Unlike `_state`, this never
        # changes during streaming; `count_reasoning_tokens` uses it to
        # classify output that contains no `</think>`.
        self._thinking_enabled: bool = False
        # Set once streaming has consumed the first `</think>`. From then on
        # every delta is plain content and nothing is buffered.
        self._end_seen: bool = False
        # Tracks whether we have applied per-request configuration yet.
        # Stack-walking covers the streaming path; `extract_reasoning` also
        # configures on the first non-streaming call as a safety net.
        self._configured: bool = False
        request = self._find_request_in_stack()
        if request is not None:
            self._configure_for_request(request)

    @staticmethod
    def _looks_like_request(obj) -> bool:
        """Duck-typed check for ChatCompletionRequest / ResponsesRequest.

        Avoids importing vLLM's protocol module, which differs across forks
        and isn't guaranteed to be importable at plugin load time. Never
        raises: the stack walk probes arbitrary caller arguments.
        """
        try:
            return hasattr(obj, "messages") and (
                hasattr(obj, "chat_template_kwargs") or hasattr(obj, "stream")
            )
        except Exception:
            # `hasattr` only swallows AttributeError; a foreign object's
            # `__getattr__` may raise anything, and request discovery is
            # best-effort, so treat such objects as "not a request".
            return False

    @classmethod
    def _find_request_in_stack(cls, max_depth: int = 12):
        """Locate the in-flight request by scanning caller-frame arguments.

        Walks at most ``max_depth`` caller frames via ``sys._getframe`` /
        ``frame.f_back`` and inspects only each frame's *function
        arguments*, never its full locals. This matches vLLM's
        ``create_chat_completion(self, request, ...)`` signature and avoids
        matching request-shaped objects that happen to live in module
        globals or unrelated locals (e.g. test fixtures). ``*args`` /
        ``**kwargs`` collectors are not inspected.

        We deliberately avoid :func:`inspect.stack`, which reads source
        files via ``linecache`` and builds ``FrameInfo`` objects for the
        whole stack on every call. That is measurable overhead per request
        under high concurrency, since parser construction is per-request and
        runs under the GIL on the serving event loop.

        Returns the first matching argument value, or ``None``.
        """
        import sys

        try:
            frame = sys._getframe(1)
        except Exception:
            return None
        depth = 0
        while frame is not None and depth < max_depth:
            code = frame.f_code
            # co_varnames lists positional (incl. positional-only) args, then
            # keyword-only args, before any other locals.
            n_args = code.co_argcount + code.co_kwonlyargcount
            # Read `f_locals` once per frame, not once per argument name. Each
            # access refreshes a snapshot dict (CPython <= 3.12) or creates a
            # new proxy object (3.13+, PEP 667).
            frame_locals = frame.f_locals
            for name in code.co_varnames[:n_args]:
                value = frame_locals.get(name)
                if cls._looks_like_request(value):
                    return value
            frame = frame.f_back
            depth += 1
        return None

    def _configure_for_request(self, request) -> None:
        """Record the request's thinking flag and set the initial streaming lane."""
        enabled = self._thinking_was_enabled(request)
        self._thinking_enabled = enabled
        self._state = "reasoning" if enabled else "content"
        self._configured = True

    def _decode(self, ids: Sequence[int]) -> str:
        """Decode ``ids`` to text, returning ``""`` if the tokenizer fails."""
        # `skip_special_tokens=False` is required: `</think>` may be tokenized
        # as (or contain) special tokens that the default decode would strip,
        # which would silently break marker detection.
        try:
            return self.model_tokenizer.decode(list(ids), skip_special_tokens=False)
        except Exception:
            return ""

    @property
    def reasoning_start_str(self) -> str | None:
        return START

    @property
    def reasoning_end_str(self) -> str | None:
        return END

    @staticmethod
    def _reasoning_closed(text: str) -> bool:
        """True iff the last ``</think>`` in ``text`` comes after the last ``<think>``.

        Comparing positions, rather than testing for any ``</think>``, keeps
        a prompt whose earlier turns contain a closed think block but which
        re-opens ``<think>`` for the current turn from being reported as
        already finished. If neither marker is present, both positions are
        -1 and the result is False.
        """
        return text.rfind(END) > text.rfind(START)

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Whether reasoning has been closed by the end of ``input_ids``."""
        return self._reasoning_closed(self._decode(input_ids))

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Whether reasoning has been closed, looking only at a trailing window.

        Decodes the last :attr:`_END_SCAN_WINDOW` ids of ``input_ids``, so the
        marker is detected even when it straddles the previous-vs-delta token
        boundary. ``delta_ids`` is not used.
        """
        n = len(input_ids)
        # Index the window directly instead of `list(input_ids)[-64:]`, which
        # copied the whole (growing) sequence on every decode step.
        tail = [input_ids[i] for i in range(max(0, n - self._END_SCAN_WINDOW), n)]
        return self._reasoning_closed(self._decode(tail))

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Token ids of the text after the first ``</think>`` in ``input_ids``.

        Splits on the *first* marker, like :meth:`extract_reasoning` and the
        streaming parser. The post-marker text is re-encoded, so the ids may
        differ from the ones originally generated. Returns ``[]`` when there
        is no marker or the tokenizer fails.
        """
        text = self._decode(input_ids)
        idx = text.find(END)
        if idx < 0:
            return []
        try:
            return self.model_tokenizer.encode(
                text[idx + len(END):], add_special_tokens=False
            )
        except Exception:
            return []

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Number of tokens (after re-encoding) that belong to reasoning.

        Reasoning is the text before the first ``</think>``. Output with no
        marker counts as reasoning only if thinking was enabled for the
        request; otherwise it is the final answer and the count is 0,
        matching how :meth:`extract_reasoning` classifies it. Returns 0 if
        the tokenizer fails.
        """
        text = self._decode(token_ids)
        idx = text.find(END)
        if idx >= 0:
            prefix = text[:idx]
        elif self._thinking_enabled:
            prefix = text
        else:
            return 0
        try:
            return len(self.model_tokenizer.encode(prefix, add_special_tokens=False))
        except Exception:
            return 0

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Split a full (non-streaming) output into ``(reasoning, content)``.

        A leading ``<think>`` is stripped if present. With a ``</think>``,
        the text before the first marker is reasoning (surrounding newlines
        stripped) and the text after it is content (leading newlines
        stripped). Without a marker, the whole text is reasoning if the
        request enabled thinking (truncated reasoning), else content. Empty
        parts are returned as ``None``.
        """
        # Configure streaming state as a side effect: a fork's serving layer
        # may call this before streaming starts, and we don't want the
        # streaming path to fall back to the `thinking=off` default if the
        # request actually had thinking enabled.
        if not self._configured:
            self._configure_for_request(request)
        s = model_output
        if s.startswith(START):
            s = s[len(START):]
        if END in s:
            reasoning, _, content = s.partition(END)
            return (reasoning.strip("\n") or None, content.lstrip("\n") or None)
        # No `</think>` in output: only treat the text as truncated reasoning
        # if we have positive evidence that thinking was enabled. Otherwise
        # it is the final answer.
        if self._thinking_was_enabled(request):
            return (s.strip("\n") or None, None)
        return (None, s.lstrip("\n") or None)

    @staticmethod
    def _thinking_was_enabled(request) -> bool:
        """Whether ``request`` asked for reasoning to be emitted.

        Mirrors the chat template's own detection so the parser stays in
        lockstep with prompt construction. It is enabled iff
        ``chat_template_kwargs.enable_thinking`` (or ``.thinking``) is
        truthy, or any system message contains the literal ``"thinking on"``
        directive (case-insensitive). System content may be a plain string or
        a list of text parts (dicts with a ``"text"`` key, or objects with a
        ``.text`` attribute).
        """
        kwargs = getattr(request, "chat_template_kwargs", None) or {}
        if kwargs.get("enable_thinking") or kwargs.get("thinking"):
            return True
        messages = getattr(request, "messages", None) or []
        for m in messages:
            role = m.get("role") if isinstance(m, dict) else getattr(m, "role", None)
            if role != "system":
                continue
            content = m.get("content") if isinstance(m, dict) else getattr(m, "content", None)
            if isinstance(content, str):
                texts = [content]
            elif isinstance(content, (list, tuple)):
                # OpenAI-style multi-part content,
                # e.g. [{"type": "text", "text": "..."}].
                texts = [
                    part.get("text") if isinstance(part, dict) else getattr(part, "text", None)
                    for part in content
                ]
            else:
                texts = []
            if any(isinstance(t, str) and "thinking on" in t.lower() for t in texts):
                return True
        return False

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> "DeltaMessage | None":
        """Emit at most one ``DeltaMessage`` per delta, routed by lane.

        Only ``delta_text`` is used; the other arguments exist for interface
        compatibility.

        Until the first ``</think>`` is seen, text is buffered. Trailing
        characters that *could* still grow into the marker on the next delta
        are held back, so the marker is never fragmented across deltas
        (e.g. ``</thi`` ... ``nk>``). When the marker arrives it is dropped:
        pre-marker text goes to the current lane and post-marker text goes to
        ``content``. After that, every delta is forwarded verbatim as
        ``content``, so a later literal ``</think>`` is kept, as it is by
        :meth:`extract_reasoning`.

        Returns ``None`` when there is nothing to emit yet. This method gets
        no end-of-stream signal. A partial-marker tail still buffered when
        generation stops is therefore never emitted; this can only happen
        before the first ``</think>``.
        """
        from vllm.entrypoints.openai.engine.protocol import DeltaMessage

        # After the split point everything is final content: no scanning and
        # no hold-back.
        if self._end_seen:
            return DeltaMessage(content=delta_text) if delta_text else None

        self._buffer += delta_text

        # Case 1 — marker fully present in the buffer: split and switch lane.
        # The pre-marker chunk stays on the *current* lane (reasoning if we
        # were inside <think>, content otherwise). The post-marker chunk
        # always goes to content, and the lane is locked to content afterwards.
        idx = self._buffer.find(END)
        if idx >= 0:
            pre = self._buffer[:idx]
            post = self._buffer[idx + len(END):]
            self._buffer = ""
            pre_lane = self._state
            self._state = "content"
            self._end_seen = True
            if not pre and not post:
                return None
            fields: dict = {}
            if pre:
                fields[pre_lane] = pre
            if post:
                # `.get` covers the edge case where pre_lane is already
                # "content" and both pre and post are non-empty; they get
                # concatenated into a single content delta.
                fields["content"] = fields.get("content", "") + post
            return DeltaMessage(**fields)

        # Case 2 — no marker yet: release everything except a possible
        # partial-marker tail, which we retain for the next delta.
        held = _max_suffix_prefix(self._buffer, END)
        safe_end = len(self._buffer) - len(held)
        if safe_end == 0:
            return None
        chunk = self._buffer[:safe_end]
        self._buffer = held
        return DeltaMessage(**{self._state: chunk})