## Summary

- replace the MiniMax usage accounting monkey patch with a runtime wrapper implementation instead of source-text rewriting
- preserve MiniMax reasoning-token semantics when `</think>` is missing by counting the emitted output as reasoning tokens
- add unit coverage for usage tracking helpers and MiniMax reasoning-token counting

## Why

The previous implementation rewrote `OpenAIServingChat` by matching exact source blocks. That was brittle against `vllm` source drift and could crash during early plugin initialization with:

`RuntimeError: Failed to locate expected block while patching OpenAIServingChat usage accounting.`

This change keeps the usage-accounting backport but applies it by wrapping the original stream/full generators and tracking output token ids at runtime.

For MiniMax reasoning counting, a missing `</think>` should not be treated as zero reasoning tokens. It can mean the whole output is still in thinking mode, or that generation stopped before the closing token was produced. In that case, the emitted output should still be counted as reasoning (see the sketch after this description).

## Validation

- `pytest -q tests/ut/patch/platform/test_patch_minimax_usage_accounting.py`
- `vllm serve --help`

Signed-off-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
Co-authored-by: QwertyJack <7554089+QwertyJack@users.noreply.github.com>
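A minimal sketch of the intended counting rule, mirroring the `_count_minimax_reasoning_tokens` helper in the patch below (the token ids are hypothetical; `9` stands in for the `</think>` token id):

```python
from __future__ import annotations

from collections.abc import Sequence


def count_reasoning_tokens(token_ids: Sequence[int], end_token_id: int | None) -> int:
    """Count output tokens emitted before the first `</think>` token."""
    if end_token_id is None:
        # No end token configured: nothing can be attributed to reasoning.
        return 0
    for idx, token_id in enumerate(token_ids):
        if token_id == end_token_id:
            # Closed think block: only the prefix is reasoning.
            return idx
    # `</think>` never appeared: the whole output is still reasoning.
    return len(token_ids)


END = 9  # hypothetical id for `</think>`
assert count_reasoning_tokens([1, 2, END, 4, 5], END) == 2  # closed think block
assert count_reasoning_tokens([1, 2, 3], END) == 3          # missing `</think>`
assert count_reasoning_tokens([1, 2, 3], None) == 0         # parser has no end id
```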
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# OpenAI chat usage accounting: backport MiniMax reasoning token accounting.
#

from __future__ import annotations

from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from typing import Any

from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
from vllm.entrypoints.openai.chat_completion import serving as chat_serving
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine import protocol as engine_protocol
from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser


class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
    reasoning_tokens: int | None = None
    audio_tokens: int | None = None
    accepted_prediction_tokens: int | None = None
    rejected_prediction_tokens: int | None = None


class UsageInfo(engine_protocol.UsageInfo):
    completion_tokens_details: CompletionTokenUsageInfo | None = None


# Present the extended models as if they were defined in the engine protocol
# module, then rebind them there so call-time name lookups in the vllm
# modules resolve to these subclasses.
CompletionTokenUsageInfo.__module__ = engine_protocol.__name__
UsageInfo.__module__ = engine_protocol.__name__

engine_protocol.CompletionTokenUsageInfo = CompletionTokenUsageInfo
engine_protocol.UsageInfo = UsageInfo
chat_protocol.UsageInfo = UsageInfo
chat_serving.CompletionTokenUsageInfo = CompletionTokenUsageInfo
chat_serving.UsageInfo = UsageInfo


def _rebuild_model_field(model_cls, field_name: str, annotation) -> None:
    model_cls.__annotations__[field_name] = annotation
    model_cls.model_fields[field_name].annotation = annotation
    model_cls.model_rebuild(force=True)


_rebuild_model_field(chat_protocol.ChatCompletionResponse, "usage", UsageInfo)
_rebuild_model_field(chat_protocol.ChatCompletionStreamResponse, "usage", UsageInfo | None)
_rebuild_model_field(engine_protocol.RequestResponseMetadata, "final_usage_info", UsageInfo | None)
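# The forced rebuilds above matter: pydantic v2 compiles a model's validation
# and serialization schema when the class is created, so swapping the field
# annotation without `model_rebuild(force=True)` would leave the old
# `UsageInfo` schema in effect and drop `completion_tokens_details`.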


def _count_minimax_reasoning_tokens(
    token_ids: Sequence[int],
    end_token_id: int | None,
) -> int:
    if end_token_id is None:
        return 0

    for idx, token_id in enumerate(token_ids):
        if token_id == end_token_id:
            return idx
    return len(token_ids)


def _patched_count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
    return _count_minimax_reasoning_tokens(token_ids, self.end_token_id)


minimax_parser.MiniMaxM2ReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens
minimax_parser.MiniMaxM2AppendThinkReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens
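# Note the deliberate semantics of the patched counting: when `</think>` is
# missing from the output, every emitted token is attributed to reasoning
# rather than reported as zero, since the model may still be mid-thought or
# may have stopped before producing the closing token.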


def _count_reasoning_tokens_for_usage(
    token_ids: Sequence[int],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return reasoning_parser.count_reasoning_tokens(token_ids)


def _make_usage_info(
    self,
    *,
    prompt_tokens: int,
    completion_tokens: int,
    num_cached_tokens: int | None = None,
    reasoning_tokens: int | None = None,
) -> UsageInfo:
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    if reasoning_tokens is not None:
        usage.completion_tokens_details = CompletionTokenUsageInfo(
            # Clamp so the reported detail never goes negative or exceeds
            # the completion total.
            reasoning_tokens=max(0, min(reasoning_tokens, completion_tokens))
        )
    if self.enable_prompt_tokens_details and num_cached_tokens:
        usage.prompt_tokens_details = chat_serving.PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
    return usage


OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reasoning_tokens_for_usage)
OpenAIServingChat._make_usage_info = _make_usage_info


@dataclass
class _UsageTrackingState:
    completion_tokens: list[int]
    raw_output_token_ids: list[list[int]]
    reasoning_parser: Any
    num_prompt_tokens: int = 0
    num_cached_tokens: int | None = None
    final_res: Any = None


def _create_usage_tracking_state(
    num_choices: int,
    reasoning_parser,
) -> _UsageTrackingState:
    return _UsageTrackingState(
        completion_tokens=[0] * num_choices,
        raw_output_token_ids=[[] for _ in range(num_choices)],
        reasoning_parser=reasoning_parser,
    )


def _update_usage_tracking_state(
    state: _UsageTrackingState,
    res,
) -> None:
    if res.prompt_token_ids is not None:
        num_prompt_tokens = len(res.prompt_token_ids)
        if res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(res.encoder_prompt_token_ids)
        state.num_prompt_tokens = num_prompt_tokens

    if state.num_cached_tokens is None:
        state.num_cached_tokens = res.num_cached_tokens

    state.final_res = res

    for output in res.outputs:
        if 0 <= output.index < len(state.completion_tokens):
            token_ids = chat_serving.as_list(output.token_ids)
            state.completion_tokens[output.index] += len(token_ids)
            state.raw_output_token_ids[output.index].extend(token_ids)
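# The tracker above is the core of the runtime-wrapper approach: instead of
# rewriting vllm source text, each result is observed as it streams through,
# accumulating per-choice output token ids (the index guard skips choices
# outside the `n` requested) so usage can be computed after generation ends.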


async def _tracked_result_generator(
    result_generator: AsyncIterator,
    state: _UsageTrackingState,
):
    async for res in result_generator:
        _update_usage_tracking_state(state, res)
        yield res


def _sum_reasoning_tokens_for_usage(
    raw_output_token_ids: list[list[int]],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return sum(
        _count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0 for token_ids in raw_output_token_ids
    )


def _make_full_response_usage(
    self,
    state: _UsageTrackingState,
) -> UsageInfo | None:
    if state.final_res is None:
        return None

    return self._make_usage_info(
        prompt_tokens=state.num_prompt_tokens,
        completion_tokens=sum(state.completion_tokens),
        num_cached_tokens=state.num_cached_tokens,
        reasoning_tokens=_sum_reasoning_tokens_for_usage(
            state.raw_output_token_ids,
            state.reasoning_parser,
        ),
    )


# Save the original coroutine once so re-importing the patch module stays
# idempotent and never wraps the wrapper.
if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
    OpenAIServingChat._ascend_original_chat_completion_full_generator = OpenAIServingChat.chat_completion_full_generator


async def _wrapped_chat_completion_full_generator(
    self,
    request: chat_protocol.ChatCompletionRequest,
    result_generator: AsyncIterator,
    request_id: str,
    model_name: str,
    conversation,
    tokenizer,
    request_metadata: engine_protocol.RequestResponseMetadata,
    reasoning_parser=None,
):
    num_choices = 1 if request.n is None else request.n
    state = _create_usage_tracking_state(num_choices, reasoning_parser)

    original_full_generator = self._ascend_original_chat_completion_full_generator
    response = await original_full_generator(
        request,
        _tracked_result_generator(result_generator, state),
        request_id,
        model_name,
        conversation,
        tokenizer,
        request_metadata,
        reasoning_parser,
    )

    # Anything other than a ChatCompletionResponse (e.g. an error result)
    # passes through untouched.
    if not isinstance(response, chat_protocol.ChatCompletionResponse):
        return response

    usage = _make_full_response_usage(self, state)
    if usage is None:
        return response

    response.usage = usage
    request_metadata.final_usage_info = usage
    return response


_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
_wrapped_chat_completion_full_generator.__qualname__ = (
    f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
)

OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator
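
# With the patch active, a chat completion for a MiniMax reasoning model is
# expected to carry the extended usage block, e.g. (illustrative numbers):
#
#   "usage": {
#       "prompt_tokens": 12,
#       "completion_tokens": 40,
#       "total_tokens": 52,
#       "completion_tokens_details": {"reasoning_tokens": 25}
#   }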