fix(platform): reimplement MiniMax usage accounting patch (#7835)

## Summary
- replace the MiniMax usage accounting monkey patch with a runtime
wrapper implementation instead of source-text rewriting
- preserve MiniMax reasoning-token semantics when `</think>` is missing
by counting the emitted output as reasoning tokens
- add unit coverage for usage tracking helpers and MiniMax
reasoning-token counting

## Why
The previous implementation rewrote `OpenAIServingChat` by matching
exact source blocks. That was brittle against `vllm` source drift and
could crash during early plugin initialization with:
`RuntimeError: Failed to locate expected block while patching
OpenAIServingChat usage accounting.`
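
For reference, the deleted patch relied on a block-replacement helper along these lines (condensed from the removed `_replace_block` in the diff below); it only works if `old` appears verbatim in the installed `vllm` source, so any upstream reformatting breaks it at plugin-import time:

```python
def replace_block(source: str, old: str, new: str, count: int = 1) -> str:
    # Fails loudly whenever the expected source block drifted upstream.
    if source.count(old) < count:
        raise RuntimeError(
            "Failed to locate expected block while patching "
            "OpenAIServingChat usage accounting."
        )
    return source.replace(old, new, count)
```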

This change keeps the usage-accounting backport, but applies it by
wrapping the original stream/full generators and tracking output token
ids at runtime.
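
A condensed sketch of the wrapper idea, simplified from the `_UsageTrackingState` and `_tracked_result_generator` helpers in the diff below (bounds checks and prompt-token tracking omitted):

```python
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any


@dataclass
class UsageTrackingState:
    # one running completion-token count and one raw token-id list per choice
    completion_tokens: list[int]
    raw_output_token_ids: list[list[int]]


async def tracked_result_generator(
    result_generator: AsyncIterator[Any],
    state: UsageTrackingState,
) -> AsyncIterator[Any]:
    # Tee each result: record its output token ids, then yield it unchanged
    # so the original full/stream generator sees exactly the same stream.
    async for res in result_generator:
        for output in res.outputs:
            token_ids = list(output.token_ids)
            state.completion_tokens[output.index] += len(token_ids)
            state.raw_output_token_ids[output.index].extend(token_ids)
        yield res
```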

For MiniMax reasoning counting, a missing `</think>` should not be
treated as zero reasoning tokens. It can mean the whole output is still
in thinking mode, or that generation stopped before the closing token
was produced. In either case, the emitted output should still be counted
as reasoning.
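
A minimal illustration of that counting rule; `THINK_END_ID` is a placeholder, not the real `</think>` vocab id:

```python
THINK_END_ID = 2  # placeholder for the tokenizer's </think> token id


def count_reasoning_tokens(token_ids: list[int]) -> int:
    if THINK_END_ID not in token_ids:
        # No closing token yet: treat the whole output as reasoning, whether
        # the model is still thinking or generation stopped early.
        return len(token_ids)
    # Otherwise only the tokens before the first </think> count as reasoning.
    return token_ids.index(THINK_END_ID)
```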

## Validation
- `pytest -q
tests/ut/patch/platform/test_patch_minimax_usage_accounting.py`
- `vllm serve --help`

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>

View File

@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
from types import SimpleNamespace
import pytest
from vllm.reasoning.minimax_m2_reasoning_parser import (
MiniMaxM2AppendThinkReasoningParser,
MiniMaxM2ReasoningParser,
)
from vllm_ascend.patch.platform import (
patch_minimax_usage_accounting as minimax_usage_patch,
)
class FakeTokenizer:
def get_vocab(self):
@@ -35,13 +40,13 @@ class FakeTokenizer:
MiniMaxM2ReasoningParser,
[10, 11, 20],
3,
id="minimax-all-tokens-are-reasoning-before-end-token",
id="minimax-no-end-token-means-all-output-is-reasoning",
),
pytest.param(
MiniMaxM2AppendThinkReasoningParser,
[10, 11, 20],
3,
id="append-think-all-tokens-are-reasoning-before-end-token",
id="append-think-no-end-token-means-all-output-is-reasoning",
),
pytest.param(
MiniMaxM2ReasoningParser,
@@ -65,3 +70,73 @@ def test_count_reasoning_tokens(
parser = parser_cls(FakeTokenizer())
assert parser.count_reasoning_tokens(token_ids) == expected_reasoning_tokens
def test_update_usage_tracking_state_tracks_prompt_and_completion_tokens():
state = minimax_usage_patch._create_usage_tracking_state(
num_choices=2,
reasoning_parser=None,
)
res = SimpleNamespace(
prompt_token_ids=[1, 2],
encoder_prompt_token_ids=[3],
num_cached_tokens=4,
outputs=[
SimpleNamespace(index=0, token_ids=(10, 11)),
SimpleNamespace(index=1, token_ids=[20]),
],
)
minimax_usage_patch._update_usage_tracking_state(state, res)
assert state.num_prompt_tokens == 3
assert state.num_cached_tokens == 4
assert state.completion_tokens == [2, 1]
assert state.raw_output_token_ids == [[10, 11], [20]]
def test_make_usage_info_injects_reasoning_token_details():
fake_serving = SimpleNamespace(enable_prompt_tokens_details=True)
usage = minimax_usage_patch._make_usage_info(
fake_serving,
prompt_tokens=3,
completion_tokens=4,
num_cached_tokens=1,
reasoning_tokens=2,
)
payload = usage.model_dump(exclude_none=True)
assert payload["completion_tokens_details"]["reasoning_tokens"] == 2
assert payload["prompt_tokens_details"]["cached_tokens"] == 1
def test_make_full_response_usage_sums_reasoning_tokens():
class FakeServing:
enable_prompt_tokens_details = False
def _make_usage_info(self, **kwargs):
return minimax_usage_patch._make_usage_info(self, **kwargs)
class FakeReasoningParser:
def count_reasoning_tokens(self, token_ids):
return 2 if 2 in token_ids else 0
state = minimax_usage_patch._create_usage_tracking_state(
num_choices=2,
reasoning_parser=FakeReasoningParser(),
)
state.num_prompt_tokens = 3
state.num_cached_tokens = 1
state.final_res = SimpleNamespace(num_cached_tokens=1)
state.completion_tokens = [4, 2]
state.raw_output_token_ids = [[10, 11, 2, 20], [30, 31]]
usage = minimax_usage_patch._make_full_response_usage(FakeServing(), state)
assert usage.prompt_tokens == 3
assert usage.completion_tokens == 6
assert usage.total_tokens == 9
assert usage.completion_tokens_details.reasoning_tokens == 2
assert usage.prompt_tokens_details is None

View File

@@ -19,10 +19,8 @@
from __future__ import annotations
import ast
import textwrap
from collections.abc import Sequence
from pathlib import Path
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from typing import Any
from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
@@ -32,46 +30,6 @@ from vllm.entrypoints.openai.engine import protocol as engine_protocol
from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser
def _extract_class_method_source(
module_path: str,
class_name: str,
method_name: str,
) -> str:
source = Path(module_path).read_text(encoding="utf-8")
tree = ast.parse(source)
for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == class_name:
for item in node.body:
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name:
method_source = ast.get_source_segment(source, item)
if method_source is None:
break
return textwrap.dedent(method_source)
raise RuntimeError(f"Unable to extract {class_name}.{method_name} from {module_path}.")
def _install_method(method_name: str, method_source: str) -> None:
namespace: dict[str, Any] = {}
exec(method_source, chat_serving.__dict__, namespace)
method = namespace[method_name]
method.__module__ = OpenAIServingChat.__module__
method.__qualname__ = f"{OpenAIServingChat.__qualname__}.{method_name}"
setattr(OpenAIServingChat, method_name, method)
def _replace_block(
source: str,
old: str,
new: str,
*,
count: int = 1,
) -> str:
if source.count(old) < count:
raise RuntimeError("Failed to locate expected block while patching OpenAIServingChat usage accounting.")
return source.replace(old, new, count)
class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
reasoning_tokens: int | None = None
audio_tokens: int | None = None
@@ -160,208 +118,132 @@ OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reason
OpenAIServingChat._make_usage_info = _make_usage_info
def _patch_chat_completion_stream_generator() -> None:
method_source = _extract_class_method_source(
chat_serving.__file__,
"OpenAIServingChat",
"chat_completion_stream_generator",
@dataclass
class _UsageTrackingState:
completion_tokens: list[int]
raw_output_token_ids: list[list[int]]
reasoning_parser: Any
num_prompt_tokens: int = 0
num_cached_tokens: int | None = None
final_res: Any = None
def _create_usage_tracking_state(
num_choices: int,
reasoning_parser,
) -> _UsageTrackingState:
return _UsageTrackingState(
completion_tokens=[0] * num_choices,
raw_output_token_ids=[[] for _ in range(num_choices)],
reasoning_parser=reasoning_parser,
)
method_source = _replace_block(
method_source,
"""\
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
""",
"""\
previous_num_tokens = [0] * num_choices
raw_output_token_ids = [[] for _ in range(num_choices)]
finish_reason_sent = [False] * num_choices
""",
def _update_usage_tracking_state(
state: _UsageTrackingState,
res,
) -> None:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids)
state.num_prompt_tokens = num_prompt_tokens
if state.num_cached_tokens is None:
state.num_cached_tokens = res.num_cached_tokens
state.final_res = res
for output in res.outputs:
if 0 <= output.index < len(state.completion_tokens):
token_ids = chat_serving.as_list(output.token_ids)
state.completion_tokens[output.index] += len(token_ids)
state.raw_output_token_ids[output.index].extend(token_ids)
async def _tracked_result_generator(
result_generator: AsyncIterator,
state: _UsageTrackingState,
):
async for res in result_generator:
_update_usage_tracking_state(state, res)
yield res
def _sum_reasoning_tokens_for_usage(
raw_output_token_ids: list[list[int]],
reasoning_parser,
) -> int | None:
if reasoning_parser is None:
return None
return sum(
_count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0 for token_ids in raw_output_token_ids
)
method_source = _replace_block(
method_source,
"""\
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
)
""",
"""\
if include_continuous_usage:
chunk.usage = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
reasoning_tokens=self._count_reasoning_tokens_for_usage(
raw_output_token_ids[i], reasoning_parser
),
)
""",
def _make_full_response_usage(
self,
state: _UsageTrackingState,
) -> UsageInfo | None:
if state.final_res is None:
return None
return self._make_usage_info(
prompt_tokens=state.num_prompt_tokens,
completion_tokens=sum(state.completion_tokens),
num_cached_tokens=state.num_cached_tokens,
reasoning_tokens=_sum_reasoning_tokens_for_usage(
state.raw_output_token_ids,
state.reasoning_parser,
),
)
method_source = _replace_block(
method_source,
"""\
if include_continuous_usage:
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens,
)
""",
"""\
if include_continuous_usage:
chunk.usage = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
reasoning_tokens=self._count_reasoning_tokens_for_usage(
raw_output_token_ids[i], reasoning_parser
),
)
""",
if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
OpenAIServingChat._ascend_original_chat_completion_full_generator = OpenAIServingChat.chat_completion_full_generator
async def _wrapped_chat_completion_full_generator(
self,
request: chat_protocol.ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
model_name: str,
conversation,
tokenizer,
request_metadata: engine_protocol.RequestResponseMetadata,
reasoning_parser=None,
):
num_choices = 1 if request.n is None else request.n
state = _create_usage_tracking_state(num_choices, reasoning_parser)
original_full_generator = self._ascend_original_chat_completion_full_generator
response = await original_full_generator(
request,
_tracked_result_generator(result_generator, state),
request_id,
model_name,
conversation,
tokenizer,
request_metadata,
reasoning_parser,
)
method_source = _replace_block(
method_source,
"""\
previous_num_tokens[i] += len(output.token_ids)
""",
"""\
previous_num_tokens[i] += len(output.token_ids)
raw_output_token_ids[i].extend(as_list(output.token_ids))
""",
)
if not isinstance(response, chat_protocol.ChatCompletionResponse):
return response
method_source = _replace_block(
method_source,
"""\
if include_continuous_usage:
completion_tokens = previous_num_tokens[i]
chunk.usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
""",
"""\
if include_continuous_usage:
completion_tokens = previous_num_tokens[i]
chunk.usage = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
reasoning_tokens=self._count_reasoning_tokens_for_usage(
raw_output_token_ids[i], reasoning_parser
),
)
""",
)
usage = _make_full_response_usage(self, state)
if usage is None:
return response
method_source = _replace_block(
method_source,
"""\
final_usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
final_usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=num_cached_tokens
)
""",
"""\
reasoning_tokens = None
if reasoning_parser is not None:
reasoning_tokens = sum(
self._count_reasoning_tokens_for_usage(
token_ids, reasoning_parser
)
or 0
for token_ids in raw_output_token_ids
)
final_usage = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
num_cached_tokens=num_cached_tokens,
reasoning_tokens=reasoning_tokens,
)
""",
)
method_source = _replace_block(
method_source,
"""\
request_metadata.final_usage_info = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
total_tokens=num_prompt_tokens + num_completion_tokens,
)
""",
"""\
reasoning_tokens = None
if reasoning_parser is not None:
reasoning_tokens = sum(
self._count_reasoning_tokens_for_usage(
token_ids, reasoning_parser
)
or 0
for token_ids in raw_output_token_ids
)
request_metadata.final_usage_info = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_completion_tokens,
reasoning_tokens=reasoning_tokens,
)
""",
)
_install_method("chat_completion_stream_generator", method_source)
response.usage = usage
request_metadata.final_usage_info = usage
return response
def _patch_chat_completion_full_generator() -> None:
method_source = _extract_class_method_source(
chat_serving.__file__,
"OpenAIServingChat",
"chat_completion_full_generator",
)
_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
_wrapped_chat_completion_full_generator.__qualname__ = (
f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
)
method_source = _replace_block(
method_source,
"""\
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
usage.prompt_tokens_details = PromptTokenUsageInfo(
cached_tokens=final_res.num_cached_tokens
)
""",
"""\
reasoning_tokens = None
if reasoning_parser is not None:
reasoning_tokens = sum(
self._count_reasoning_tokens_for_usage(
as_list(output.token_ids), reasoning_parser
)
or 0
for output in final_res.outputs
)
usage = self._make_usage_info(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
num_cached_tokens=final_res.num_cached_tokens,
reasoning_tokens=reasoning_tokens,
)
""",
)
_install_method("chat_completion_full_generator", method_source)
_patch_chat_completion_stream_generator()
_patch_chat_completion_full_generator()
OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator