fix(platform): reimplement MiniMax usage accounting patch (#7835)
## Summary

- replace the MiniMax usage accounting monkey patch with a runtime wrapper implementation instead of source-text rewriting
- preserve MiniMax reasoning-token semantics when `</think>` is missing by counting the emitted output as reasoning tokens
- add unit coverage for usage tracking helpers and MiniMax reasoning-token counting

## Why

The previous implementation rewrote `OpenAIServingChat` by matching exact source blocks. That was brittle against `vllm` source drift and could crash during early plugin initialization with:

`RuntimeError: Failed to locate expected block while patching OpenAIServingChat usage accounting.`

This change keeps the usage-accounting backport, but applies it by wrapping the original stream/full generators and tracking output token ids at runtime.

For MiniMax reasoning counting, a missing `</think>` should not be treated as zero reasoning tokens. It can mean the whole output is still in thinking mode, or that generation stopped before the closing token was produced. In that case, the emitted output should still be counted as reasoning.

## Validation

- `pytest -q tests/ut/patch/platform/test_patch_minimax_usage_accounting.py`
- `vllm serve --help`

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
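For context, here is a minimal, self-contained sketch of the two ideas above: an async-generator wrapper that records output token ids as a side effect, and the rule that a missing `</think>` counts all emitted output as reasoning. The names `UsageState`, `tracked`, `count_reasoning_tokens`, and the fake result shape are illustrative only and do not mirror the actual vllm types:

```python
import asyncio
from dataclasses import dataclass, field

THINK_END = 99  # stand-in id for the </think> token


@dataclass
class UsageState:
    # output token ids accumulated per choice while streaming
    raw_output_token_ids: list = field(default_factory=lambda: [[]])


async def tracked(results, state):
    # pass each result through unchanged, recording its tokens as a side effect
    async for res in results:
        for index, token_ids in res.items():
            state.raw_output_token_ids[index].extend(token_ids)
        yield res


def count_reasoning_tokens(token_ids):
    # without </think>, everything emitted so far counts as reasoning
    if THINK_END not in token_ids:
        return len(token_ids)
    return token_ids.index(THINK_END)


async def main():
    async def fake_results():
        # two streamed chunks for choice 0; </think> never appears
        yield {0: [10, 11]}
        yield {0: [20]}

    state = UsageState()
    async for _ in tracked(fake_results(), state):
        pass
    print(count_reasoning_tokens(state.raw_output_token_ids[0]))  # -> 3


asyncio.run(main())
```

This mirrors the shape of the wrapper in the diff below, where `_tracked_result_generator` feeds the original `chat_completion_full_generator` and the recorded ids are turned into `completion_tokens_details.reasoning_tokens`.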
@@ -1,12 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from types import SimpleNamespace

import pytest
from vllm.reasoning.minimax_m2_reasoning_parser import (
    MiniMaxM2AppendThinkReasoningParser,
    MiniMaxM2ReasoningParser,
)

from vllm_ascend.patch.platform import (
    patch_minimax_usage_accounting as minimax_usage_patch,
)


class FakeTokenizer:
    def get_vocab(self):
@@ -35,13 +40,13 @@ class FakeTokenizer:
        MiniMaxM2ReasoningParser,
        [10, 11, 20],
        3,
        id="minimax-all-tokens-are-reasoning-before-end-token",
        id="minimax-no-end-token-means-all-output-is-reasoning",
    ),
    pytest.param(
        MiniMaxM2AppendThinkReasoningParser,
        [10, 11, 20],
        3,
        id="append-think-all-tokens-are-reasoning-before-end-token",
        id="append-think-no-end-token-means-all-output-is-reasoning",
    ),
    pytest.param(
        MiniMaxM2ReasoningParser,
@@ -65,3 +70,73 @@ def test_count_reasoning_tokens(
    parser = parser_cls(FakeTokenizer())

    assert parser.count_reasoning_tokens(token_ids) == expected_reasoning_tokens


def test_update_usage_tracking_state_tracks_prompt_and_completion_tokens():
    state = minimax_usage_patch._create_usage_tracking_state(
        num_choices=2,
        reasoning_parser=None,
    )

    res = SimpleNamespace(
        prompt_token_ids=[1, 2],
        encoder_prompt_token_ids=[3],
        num_cached_tokens=4,
        outputs=[
            SimpleNamespace(index=0, token_ids=(10, 11)),
            SimpleNamespace(index=1, token_ids=[20]),
        ],
    )

    minimax_usage_patch._update_usage_tracking_state(state, res)

    assert state.num_prompt_tokens == 3
    assert state.num_cached_tokens == 4
    assert state.completion_tokens == [2, 1]
    assert state.raw_output_token_ids == [[10, 11], [20]]


def test_make_usage_info_injects_reasoning_token_details():
    fake_serving = SimpleNamespace(enable_prompt_tokens_details=True)
    usage = minimax_usage_patch._make_usage_info(
        fake_serving,
        prompt_tokens=3,
        completion_tokens=4,
        num_cached_tokens=1,
        reasoning_tokens=2,
    )

    payload = usage.model_dump(exclude_none=True)

    assert payload["completion_tokens_details"]["reasoning_tokens"] == 2
    assert payload["prompt_tokens_details"]["cached_tokens"] == 1


def test_make_full_response_usage_sums_reasoning_tokens():
    class FakeServing:
        enable_prompt_tokens_details = False

        def _make_usage_info(self, **kwargs):
            return minimax_usage_patch._make_usage_info(self, **kwargs)

    class FakeReasoningParser:
        def count_reasoning_tokens(self, token_ids):
            return 2 if 2 in token_ids else 0

    state = minimax_usage_patch._create_usage_tracking_state(
        num_choices=2,
        reasoning_parser=FakeReasoningParser(),
    )
    state.num_prompt_tokens = 3
    state.num_cached_tokens = 1
    state.final_res = SimpleNamespace(num_cached_tokens=1)
    state.completion_tokens = [4, 2]
    state.raw_output_token_ids = [[10, 11, 2, 20], [30, 31]]

    usage = minimax_usage_patch._make_full_response_usage(FakeServing(), state)

    assert usage.prompt_tokens == 3
    assert usage.completion_tokens == 6
    assert usage.total_tokens == 9
    assert usage.completion_tokens_details.reasoning_tokens == 2
    assert usage.prompt_tokens_details is None

@@ -19,10 +19,8 @@

from __future__ import annotations

import ast
import textwrap
from collections.abc import Sequence
from pathlib import Path
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from typing import Any

from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
@@ -32,46 +30,6 @@ from vllm.entrypoints.openai.engine import protocol as engine_protocol
from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser


def _extract_class_method_source(
    module_path: str,
    class_name: str,
    method_name: str,
) -> str:
    source = Path(module_path).read_text(encoding="utf-8")
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            for item in node.body:
                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name:
                    method_source = ast.get_source_segment(source, item)
                    if method_source is None:
                        break
                    return textwrap.dedent(method_source)

    raise RuntimeError(f"Unable to extract {class_name}.{method_name} from {module_path}.")


def _install_method(method_name: str, method_source: str) -> None:
    namespace: dict[str, Any] = {}
    exec(method_source, chat_serving.__dict__, namespace)
    method = namespace[method_name]
    method.__module__ = OpenAIServingChat.__module__
    method.__qualname__ = f"{OpenAIServingChat.__qualname__}.{method_name}"
    setattr(OpenAIServingChat, method_name, method)


def _replace_block(
    source: str,
    old: str,
    new: str,
    *,
    count: int = 1,
) -> str:
    if source.count(old) < count:
        raise RuntimeError("Failed to locate expected block while patching OpenAIServingChat usage accounting.")
    return source.replace(old, new, count)


class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
    reasoning_tokens: int | None = None
    audio_tokens: int | None = None
@@ -160,208 +118,132 @@ OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reason
OpenAIServingChat._make_usage_info = _make_usage_info


def _patch_chat_completion_stream_generator() -> None:
    method_source = _extract_class_method_source(
        chat_serving.__file__,
        "OpenAIServingChat",
        "chat_completion_stream_generator",
@dataclass
class _UsageTrackingState:
    completion_tokens: list[int]
    raw_output_token_ids: list[list[int]]
    reasoning_parser: Any
    num_prompt_tokens: int = 0
    num_cached_tokens: int | None = None
    final_res: Any = None


def _create_usage_tracking_state(
    num_choices: int,
    reasoning_parser,
) -> _UsageTrackingState:
    return _UsageTrackingState(
        completion_tokens=[0] * num_choices,
        raw_output_token_ids=[[] for _ in range(num_choices)],
        reasoning_parser=reasoning_parser,
    )

    method_source = _replace_block(
        method_source,
        """\
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
""",
        """\
previous_num_tokens = [0] * num_choices
raw_output_token_ids = [[] for _ in range(num_choices)]
finish_reason_sent = [False] * num_choices
""",

def _update_usage_tracking_state(
    state: _UsageTrackingState,
    res,
) -> None:
    if res.prompt_token_ids is not None:
        num_prompt_tokens = len(res.prompt_token_ids)
        if res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(res.encoder_prompt_token_ids)
        state.num_prompt_tokens = num_prompt_tokens

    if state.num_cached_tokens is None:
        state.num_cached_tokens = res.num_cached_tokens

    state.final_res = res

    for output in res.outputs:
        if 0 <= output.index < len(state.completion_tokens):
            token_ids = chat_serving.as_list(output.token_ids)
            state.completion_tokens[output.index] += len(token_ids)
            state.raw_output_token_ids[output.index].extend(token_ids)


async def _tracked_result_generator(
    result_generator: AsyncIterator,
    state: _UsageTrackingState,
):
    async for res in result_generator:
        _update_usage_tracking_state(state, res)
        yield res


def _sum_reasoning_tokens_for_usage(
    raw_output_token_ids: list[list[int]],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return sum(
        _count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0 for token_ids in raw_output_token_ids
    )

    method_source = _replace_block(
        method_source,
        """\
if include_continuous_usage:
    chunk.usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=0,
        total_tokens=num_prompt_tokens,
    )
""",
        """\
if include_continuous_usage:
    chunk.usage = self._make_usage_info(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=0,
        reasoning_tokens=self._count_reasoning_tokens_for_usage(
            raw_output_token_ids[i], reasoning_parser
        ),
    )
""",

def _make_full_response_usage(
    self,
    state: _UsageTrackingState,
) -> UsageInfo | None:
    if state.final_res is None:
        return None

    return self._make_usage_info(
        prompt_tokens=state.num_prompt_tokens,
        completion_tokens=sum(state.completion_tokens),
        num_cached_tokens=state.num_cached_tokens,
        reasoning_tokens=_sum_reasoning_tokens_for_usage(
            state.raw_output_token_ids,
            state.reasoning_parser,
        ),
    )

    method_source = _replace_block(
        method_source,
        """\
if include_continuous_usage:
    chunk.usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=0,
        total_tokens=num_prompt_tokens,
    )
""",
        """\
if include_continuous_usage:
    chunk.usage = self._make_usage_info(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=0,
        reasoning_tokens=self._count_reasoning_tokens_for_usage(
            raw_output_token_ids[i], reasoning_parser
        ),
    )
""",

if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
    OpenAIServingChat._ascend_original_chat_completion_full_generator = OpenAIServingChat.chat_completion_full_generator


async def _wrapped_chat_completion_full_generator(
    self,
    request: chat_protocol.ChatCompletionRequest,
    result_generator: AsyncIterator,
    request_id: str,
    model_name: str,
    conversation,
    tokenizer,
    request_metadata: engine_protocol.RequestResponseMetadata,
    reasoning_parser=None,
):
    num_choices = 1 if request.n is None else request.n
    state = _create_usage_tracking_state(num_choices, reasoning_parser)

    original_full_generator = self._ascend_original_chat_completion_full_generator
    response = await original_full_generator(
        request,
        _tracked_result_generator(result_generator, state),
        request_id,
        model_name,
        conversation,
        tokenizer,
        request_metadata,
        reasoning_parser,
    )

    method_source = _replace_block(
        method_source,
        """\
previous_num_tokens[i] += len(output.token_ids)
""",
        """\
previous_num_tokens[i] += len(output.token_ids)
raw_output_token_ids[i].extend(as_list(output.token_ids))
""",
    )
    if not isinstance(response, chat_protocol.ChatCompletionResponse):
        return response

    method_source = _replace_block(
        method_source,
        """\
if include_continuous_usage:
    completion_tokens = previous_num_tokens[i]
    chunk.usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=num_prompt_tokens + completion_tokens,
    )
""",
        """\
if include_continuous_usage:
    completion_tokens = previous_num_tokens[i]
    chunk.usage = self._make_usage_info(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=completion_tokens,
        reasoning_tokens=self._count_reasoning_tokens_for_usage(
            raw_output_token_ids[i], reasoning_parser
        ),
    )
""",
    )
    usage = _make_full_response_usage(self, state)
    if usage is None:
        return response

    method_source = _replace_block(
        method_source,
        """\
final_usage = UsageInfo(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=completion_tokens,
    total_tokens=num_prompt_tokens + completion_tokens,
)
if self.enable_prompt_tokens_details and num_cached_tokens:
    final_usage.prompt_tokens_details = PromptTokenUsageInfo(
        cached_tokens=num_cached_tokens
    )
""",
        """\
reasoning_tokens = None
if reasoning_parser is not None:
    reasoning_tokens = sum(
        self._count_reasoning_tokens_for_usage(
            token_ids, reasoning_parser
        )
        or 0
        for token_ids in raw_output_token_ids
    )
final_usage = self._make_usage_info(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=completion_tokens,
    num_cached_tokens=num_cached_tokens,
    reasoning_tokens=reasoning_tokens,
)
""",
    )

    method_source = _replace_block(
        method_source,
        """\
request_metadata.final_usage_info = UsageInfo(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=num_completion_tokens,
    total_tokens=num_prompt_tokens + num_completion_tokens,
)
""",
        """\
reasoning_tokens = None
if reasoning_parser is not None:
    reasoning_tokens = sum(
        self._count_reasoning_tokens_for_usage(
            token_ids, reasoning_parser
        )
        or 0
        for token_ids in raw_output_token_ids
    )
request_metadata.final_usage_info = self._make_usage_info(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=num_completion_tokens,
    reasoning_tokens=reasoning_tokens,
)
""",
    )

    _install_method("chat_completion_stream_generator", method_source)
    response.usage = usage
    request_metadata.final_usage_info = usage
    return response


def _patch_chat_completion_full_generator() -> None:
    method_source = _extract_class_method_source(
        chat_serving.__file__,
        "OpenAIServingChat",
        "chat_completion_full_generator",
    )
_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
_wrapped_chat_completion_full_generator.__qualname__ = (
    f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
)

    method_source = _replace_block(
        method_source,
        """\
usage = UsageInfo(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=num_generated_tokens,
    total_tokens=num_prompt_tokens + num_generated_tokens,
)
if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
    usage.prompt_tokens_details = PromptTokenUsageInfo(
        cached_tokens=final_res.num_cached_tokens
    )
""",
        """\
reasoning_tokens = None
if reasoning_parser is not None:
    reasoning_tokens = sum(
        self._count_reasoning_tokens_for_usage(
            as_list(output.token_ids), reasoning_parser
        )
        or 0
        for output in final_res.outputs
    )
usage = self._make_usage_info(
    prompt_tokens=num_prompt_tokens,
    completion_tokens=num_generated_tokens,
    num_cached_tokens=final_res.num_cached_tokens,
    reasoning_tokens=reasoning_tokens,
)
""",
    )

    _install_method("chat_completion_full_generator", method_source)


_patch_chat_completion_stream_generator()
_patch_chat_completion_full_generator()
OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator