Files
xc-llm-ascend/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
jack 7314bbe2df fix(platform): reimplement MiniMax usage accounting patch (#7835)
## Summary
- replace the MiniMax usage-accounting monkey patch with a runtime
wrapper instead of source-text rewriting
- preserve MiniMax reasoning-token semantics when `</think>` is missing
by counting the emitted output as reasoning tokens
- add unit coverage for usage tracking helpers and MiniMax
reasoning-token counting

## Why
The previous implementation rewrote `OpenAIServingChat` by matching
exact source blocks. That was brittle against `vllm` source drift and
could crash during early plugin initialization with:
`RuntimeError: Failed to locate expected block while patching
OpenAIServingChat usage accounting.`

This change keeps the usage-accounting backport, but applies it by
wrapping the original stream/full generators and tracking output token
ids at runtime.
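In miniature, the wrapper is an async pass-through that records emitted token
ids before yielding each result (a rough sketch only; the actual patch below
keeps per-choice state in a dedicated dataclass):

```python
# Rough sketch: `state` here is a plain dict keyed by choice index.
async def _tracked(result_generator, state):
    async for res in result_generator:
        for output in res.outputs:
            state.setdefault(output.index, []).extend(output.token_ids)  # record output token ids
        yield res  # pass each result through unchanged
```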

For MiniMax reasoning counting, a missing `</think>` should not be
treated as zero reasoning tokens. It can mean the whole output is still
in thinking mode, or that generation stopped before the closing token
was produced. In either case, the emitted output should still be counted
as reasoning.
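For illustration, the counting helper below behaves like this on made-up token
ids (`end_token_id` is the id of the `</think>` token):

```python
# Hypothetical ids, purely illustrative; 9 stands in for the `</think>` token id.
_count_minimax_reasoning_tokens([7, 8, 9, 5, 6], end_token_id=9)  # -> 2: tokens before `</think>`
_count_minimax_reasoning_tokens([7, 8, 5, 6], end_token_id=9)     # -> 4: no `</think>`, whole output is reasoning
```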

## Validation
- `pytest -q
tests/ut/patch/platform/test_patch_minimax_usage_accounting.py`
- `vllm serve --help`

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
2026-03-31 16:27:00 +08:00


#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# OpenAI chat usage accounting: backport MiniMax reasoning token accounting.
#
from __future__ import annotations

from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from typing import Any

from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
from vllm.entrypoints.openai.chat_completion import serving as chat_serving
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine import protocol as engine_protocol
from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser


class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
    reasoning_tokens: int | None = None
    audio_tokens: int | None = None
    accepted_prediction_tokens: int | None = None
    rejected_prediction_tokens: int | None = None


class UsageInfo(engine_protocol.UsageInfo):
    completion_tokens_details: CompletionTokenUsageInfo | None = None


CompletionTokenUsageInfo.__module__ = engine_protocol.__name__
UsageInfo.__module__ = engine_protocol.__name__
engine_protocol.CompletionTokenUsageInfo = CompletionTokenUsageInfo
engine_protocol.UsageInfo = UsageInfo
chat_protocol.UsageInfo = UsageInfo
chat_serving.CompletionTokenUsageInfo = CompletionTokenUsageInfo
chat_serving.UsageInfo = UsageInfo


def _rebuild_model_field(model_cls, field_name: str, annotation) -> None:
    model_cls.__annotations__[field_name] = annotation
    model_cls.model_fields[field_name].annotation = annotation
    model_cls.model_rebuild(force=True)


_rebuild_model_field(chat_protocol.ChatCompletionResponse, "usage", UsageInfo)
_rebuild_model_field(chat_protocol.ChatCompletionStreamResponse, "usage", UsageInfo | None)
_rebuild_model_field(engine_protocol.RequestResponseMetadata, "final_usage_info", UsageInfo | None)
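

# Tokens before the first `</think>` token count as reasoning tokens. When the
# end token never appears (the output is still "thinking", or generation
# stopped before the closing token), the whole emitted output counts as
# reasoning.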
def _count_minimax_reasoning_tokens(
    token_ids: Sequence[int],
    end_token_id: int | None,
) -> int:
    if end_token_id is None:
        return 0
    for idx, token_id in enumerate(token_ids):
        if token_id == end_token_id:
            return idx
    return len(token_ids)


def _patched_count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
    return _count_minimax_reasoning_tokens(token_ids, self.end_token_id)


minimax_parser.MiniMaxM2ReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens
minimax_parser.MiniMaxM2AppendThinkReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens


def _count_reasoning_tokens_for_usage(
    token_ids: Sequence[int],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return reasoning_parser.count_reasoning_tokens(token_ids)


def _make_usage_info(
    self,
    *,
    prompt_tokens: int,
    completion_tokens: int,
    num_cached_tokens: int | None = None,
    reasoning_tokens: int | None = None,
) -> UsageInfo:
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    if reasoning_tokens is not None:
        usage.completion_tokens_details = CompletionTokenUsageInfo(
            reasoning_tokens=max(0, min(reasoning_tokens, completion_tokens))
        )
    if self.enable_prompt_tokens_details and num_cached_tokens:
        usage.prompt_tokens_details = chat_serving.PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
    return usage


OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reasoning_tokens_for_usage)
OpenAIServingChat._make_usage_info = _make_usage_info
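

# Runtime tracking: rather than rewriting vllm source text, the original
# generators are wrapped so output token ids are accumulated per choice while
# results stream through.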
@dataclass
class _UsageTrackingState:
    completion_tokens: list[int]
    raw_output_token_ids: list[list[int]]
    reasoning_parser: Any
    num_prompt_tokens: int = 0
    num_cached_tokens: int | None = None
    final_res: Any = None


def _create_usage_tracking_state(
    num_choices: int,
    reasoning_parser,
) -> _UsageTrackingState:
    return _UsageTrackingState(
        completion_tokens=[0] * num_choices,
        raw_output_token_ids=[[] for _ in range(num_choices)],
        reasoning_parser=reasoning_parser,
    )


def _update_usage_tracking_state(
    state: _UsageTrackingState,
    res,
) -> None:
    if res.prompt_token_ids is not None:
        num_prompt_tokens = len(res.prompt_token_ids)
        if res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(res.encoder_prompt_token_ids)
        state.num_prompt_tokens = num_prompt_tokens
    if state.num_cached_tokens is None:
        state.num_cached_tokens = res.num_cached_tokens
    state.final_res = res
    for output in res.outputs:
        if 0 <= output.index < len(state.completion_tokens):
            token_ids = chat_serving.as_list(output.token_ids)
            state.completion_tokens[output.index] += len(token_ids)
            state.raw_output_token_ids[output.index].extend(token_ids)


async def _tracked_result_generator(
    result_generator: AsyncIterator,
    state: _UsageTrackingState,
):
    async for res in result_generator:
        _update_usage_tracking_state(state, res)
        yield res


def _sum_reasoning_tokens_for_usage(
    raw_output_token_ids: list[list[int]],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return sum(
        _count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0 for token_ids in raw_output_token_ids
    )


def _make_full_response_usage(
    self,
    state: _UsageTrackingState,
) -> UsageInfo | None:
    if state.final_res is None:
        return None
    return self._make_usage_info(
        prompt_tokens=state.num_prompt_tokens,
        completion_tokens=sum(state.completion_tokens),
        num_cached_tokens=state.num_cached_tokens,
        reasoning_tokens=_sum_reasoning_tokens_for_usage(
            state.raw_output_token_ids,
            state.reasoning_parser,
        ),
    )
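

# Keep a reference to the original full generator so the wrapper can delegate
# to it and the patch stays idempotent if this module is loaded more than once.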
if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
    OpenAIServingChat._ascend_original_chat_completion_full_generator = OpenAIServingChat.chat_completion_full_generator


async def _wrapped_chat_completion_full_generator(
    self,
    request: chat_protocol.ChatCompletionRequest,
    result_generator: AsyncIterator,
    request_id: str,
    model_name: str,
    conversation,
    tokenizer,
    request_metadata: engine_protocol.RequestResponseMetadata,
    reasoning_parser=None,
):
    num_choices = 1 if request.n is None else request.n
    state = _create_usage_tracking_state(num_choices, reasoning_parser)
    original_full_generator = self._ascend_original_chat_completion_full_generator
    response = await original_full_generator(
        request,
        _tracked_result_generator(result_generator, state),
        request_id,
        model_name,
        conversation,
        tokenizer,
        request_metadata,
        reasoning_parser,
    )
    if not isinstance(response, chat_protocol.ChatCompletionResponse):
        return response
    usage = _make_full_response_usage(self, state)
    if usage is None:
        return response
    response.usage = usage
    request_metadata.final_usage_info = usage
    return response


_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
_wrapped_chat_completion_full_generator.__qualname__ = (
    f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
)
OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator