diff --git a/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py b/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py
new file mode 100644
index 00000000..3456687d
--- /dev/null
+++ b/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py
@@ -0,0 +1,67 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from vllm.reasoning.minimax_m2_reasoning_parser import (
+    MiniMaxM2AppendThinkReasoningParser,
+    MiniMaxM2ReasoningParser,
+)
+
+
+class FakeTokenizer:
+    def get_vocab(self):
+        return {
+            "<think>": 1,
+            "</think>": 2,
+        }
+
+
+@pytest.mark.parametrize(
+    ("parser_cls", "token_ids", "expected_reasoning_tokens"),
+    [
+        pytest.param(
+            MiniMaxM2ReasoningParser,
+            [10, 11, 2, 20],
+            2,
+            id="minimax-reasoning-before-end-token",
+        ),
+        pytest.param(
+            MiniMaxM2AppendThinkReasoningParser,
+            [10, 11, 2, 20],
+            2,
+            id="append-think-reasoning-before-end-token",
+        ),
+        pytest.param(
+            MiniMaxM2ReasoningParser,
+            [10, 11, 20],
+            3,
+            id="minimax-all-tokens-are-reasoning-before-end-token",
+        ),
+        pytest.param(
+            MiniMaxM2AppendThinkReasoningParser,
+            [10, 11, 20],
+            3,
+            id="append-think-all-tokens-are-reasoning-before-end-token",
+        ),
+        pytest.param(
+            MiniMaxM2ReasoningParser,
+            [2, 20],
+            0,
+            id="minimax-end-token-first-means-no-reasoning-tokens",
+        ),
+        pytest.param(
+            MiniMaxM2AppendThinkReasoningParser,
+            [2, 20],
+            0,
+            id="append-think-end-token-first-means-no-reasoning-tokens",
+        ),
+    ],
+)
+def test_count_reasoning_tokens(
+    parser_cls,
+    token_ids,
+    expected_reasoning_tokens,
+):
+    parser = parser_cls(FakeTokenizer())
+
+    assert parser.count_reasoning_tokens(token_ids) == expected_reasoning_tokens
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 1b820631..7bc9c545 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -159,6 +159,26 @@
 #    Future Plan:
 #       Remove this patch after the upcoming KV cache spec refactor.
 #
+# ** 9. File: platform/patch_minimax_usage_accounting.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `vllm.entrypoints.openai.chat_completion.serving.OpenAIServingChat`
+#       `vllm.entrypoints.openai.engine.protocol.UsageInfo`
+#       `vllm.reasoning.minimax_m2_reasoning_parser`
+#    Why:
+#       MiniMax M2 reasoning outputs use `</think>` as the only boundary token,
+#       but the runtime usage accounting path either omits reasoning token
+#       details entirely or counts them incorrectly.
+#    How:
+#       Monkey-patch the MiniMax reasoning token counters, extend `UsageInfo`
+#       with `completion_tokens_details.reasoning_tokens`, and update chat
+#       streaming/non-streaming usage generation to propagate the corrected
+#       counts.
+#    Related PR (if no, explain why):
+#       https://github.com/vllm-project/vllm/pull/37955
+#    Future Plan:
+#       Remove this patch once the upstream MiniMax usage-accounting fix is in
+#       the runtime vLLM version used by vllm-ascend.
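+#    Example (illustrative): after this patch a chat completion response can
+#       carry a usage payload such as the following (token counts are
+#       hypothetical):
+#       "usage": {"prompt_tokens": 120, "completion_tokens": 48,
+#                 "total_tokens": 168,
+#                 "completion_tokens_details": {"reasoning_tokens": 30}}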
+#
 # * Worker Patch:
 # ===============
 #
diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py
index 28d87efb..4f7771ed 100644
--- a/vllm_ascend/patch/platform/__init__.py
+++ b/vllm_ascend/patch/platform/__init__.py
@@ -29,6 +29,7 @@ else:
 import vllm_ascend.patch.platform.patch_minimax_m2_config  # noqa
 import vllm_ascend.patch.platform.patch_sched_yield  # noqa
 import vllm_ascend.patch.platform.patch_torch_accelerator  # noqa
+import vllm_ascend.patch.platform.patch_minimax_usage_accounting  # noqa
 
 if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
     import vllm_ascend.patch.platform.patch_multiproc_executor  # noqa
diff --git a/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py b/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
new file mode 100644
index 00000000..6c416cb3
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
@@ -0,0 +1,367 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# OpenAI chat usage accounting: backport MiniMax reasoning token accounting.
+#
+
+from __future__ import annotations
+
+import ast
+import textwrap
+from collections.abc import Sequence
+from pathlib import Path
+from typing import Any
+
+from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
+from vllm.entrypoints.openai.chat_completion import serving as chat_serving
+from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
+from vllm.entrypoints.openai.engine import protocol as engine_protocol
+from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser
+
+
+def _extract_class_method_source(
+    module_path: str,
+    class_name: str,
+    method_name: str,
+) -> str:
+    source = Path(module_path).read_text(encoding="utf-8")
+    tree = ast.parse(source)
+    for node in tree.body:
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            for item in node.body:
+                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name:
+                    method_source = ast.get_source_segment(source, item)
+                    if method_source is None:
+                        break
+                    return textwrap.dedent(method_source)
+
+    raise RuntimeError(f"Unable to extract {class_name}.{method_name} from {module_path}.")
+
+
+def _install_method(method_name: str, method_source: str) -> None:
+    namespace: dict[str, Any] = {}
+    exec(method_source, chat_serving.__dict__, namespace)
+    method = namespace[method_name]
+    method.__module__ = OpenAIServingChat.__module__
+    method.__qualname__ = f"{OpenAIServingChat.__qualname__}.{method_name}"
+    setattr(OpenAIServingChat, method_name, method)
+
+
+def _replace_block(
+    source: str,
+    old: str,
+    new: str,
+    *,
+    count: int = 1,
+) -> str:
+    if source.count(old) < count:
+        raise RuntimeError("Failed to locate expected block while patching OpenAIServingChat usage accounting.")
+    return source.replace(old, new, count)
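+
+
+# The helpers above implement the patching strategy used in the rest of this
+# module: read the installed OpenAIServingChat method source from disk,
+# textually swap the known usage-accounting blocks, then exec the rewritten
+# method back onto the class.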
+
+
+class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
+    reasoning_tokens: int | None = None
+    audio_tokens: int | None = None
+    accepted_prediction_tokens: int | None = None
+    rejected_prediction_tokens: int | None = None
+
+
+class UsageInfo(engine_protocol.UsageInfo):
+    completion_tokens_details: CompletionTokenUsageInfo | None = None
+
+
+CompletionTokenUsageInfo.__module__ = engine_protocol.__name__
+UsageInfo.__module__ = engine_protocol.__name__
+
+engine_protocol.CompletionTokenUsageInfo = CompletionTokenUsageInfo
+engine_protocol.UsageInfo = UsageInfo
+chat_protocol.UsageInfo = UsageInfo
+chat_serving.CompletionTokenUsageInfo = CompletionTokenUsageInfo
+chat_serving.UsageInfo = UsageInfo
+
+
+def _rebuild_model_field(model_cls, field_name: str, annotation) -> None:
+    model_cls.__annotations__[field_name] = annotation
+    model_cls.model_fields[field_name].annotation = annotation
+    model_cls.model_rebuild(force=True)
+
+
+_rebuild_model_field(chat_protocol.ChatCompletionResponse, "usage", UsageInfo)
+_rebuild_model_field(chat_protocol.ChatCompletionStreamResponse, "usage", UsageInfo | None)
+_rebuild_model_field(engine_protocol.RequestResponseMetadata, "final_usage_info", UsageInfo | None)
+
+
+def _count_minimax_reasoning_tokens(
+    token_ids: Sequence[int],
+    end_token_id: int | None,
+) -> int:
+    if end_token_id is None:
+        return 0
+
+    for idx, token_id in enumerate(token_ids):
+        if token_id == end_token_id:
+            return idx
+    return len(token_ids)
+
+
+def _patched_count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
+    return _count_minimax_reasoning_tokens(token_ids, self.end_token_id)
+
+
+minimax_parser.MiniMaxM2ReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens
+minimax_parser.MiniMaxM2AppendThinkReasoningParser.count_reasoning_tokens = _patched_count_reasoning_tokens
+
+
+def _count_reasoning_tokens_for_usage(
+    token_ids: Sequence[int],
+    reasoning_parser,
+) -> int | None:
+    if reasoning_parser is None:
+        return None
+    return reasoning_parser.count_reasoning_tokens(token_ids)
+
+
+def _make_usage_info(
+    self,
+    *,
+    prompt_tokens: int,
+    completion_tokens: int,
+    num_cached_tokens: int | None = None,
+    reasoning_tokens: int | None = None,
+) -> UsageInfo:
+    usage = UsageInfo(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    if reasoning_tokens is not None:
+        usage.completion_tokens_details = CompletionTokenUsageInfo(
+            reasoning_tokens=max(0, min(reasoning_tokens, completion_tokens))
+        )
+    if self.enable_prompt_tokens_details and num_cached_tokens:
+        usage.prompt_tokens_details = chat_serving.PromptTokenUsageInfo(cached_tokens=num_cached_tokens)
+    return usage
+
+
+OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reasoning_tokens_for_usage)
+OpenAIServingChat._make_usage_info = _make_usage_info
+
+
+def _patch_chat_completion_stream_generator() -> None:
+    method_source = _extract_class_method_source(
+        chat_serving.__file__,
+        "OpenAIServingChat",
+        "chat_completion_stream_generator",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+    previous_num_tokens = [0] * num_choices
+    finish_reason_sent = [False] * num_choices
+""",
+        """\
+    previous_num_tokens = [0] * num_choices
+    raw_output_token_ids = [[] for _ in range(num_choices)]
+    finish_reason_sent = [False] * num_choices
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=0,
+                            total_tokens=num_prompt_tokens,
+                        )
+""",
+        """\
+                    if include_continuous_usage:
+                        chunk.usage = self._make_usage_info(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=0,
+                            reasoning_tokens=self._count_reasoning_tokens_for_usage(
+                                raw_output_token_ids[i], reasoning_parser
+                            ),
+                        )
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+                    if include_continuous_usage:
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=0,
+                            total_tokens=num_prompt_tokens,
+                        )
+""",
+        """\
+                    if include_continuous_usage:
+                        chunk.usage = self._make_usage_info(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=0,
+                            reasoning_tokens=self._count_reasoning_tokens_for_usage(
+                                raw_output_token_ids[i], reasoning_parser
+                            ),
+                        )
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+                previous_num_tokens[i] += len(output.token_ids)
+""",
+        """\
+                previous_num_tokens[i] += len(output.token_ids)
+                raw_output_token_ids[i].extend(as_list(output.token_ids))
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+                if include_continuous_usage:
+                    completion_tokens = previous_num_tokens[i]
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=num_prompt_tokens + completion_tokens,
+                    )
+""",
+        """\
+                if include_continuous_usage:
+                    completion_tokens = previous_num_tokens[i]
+                    chunk.usage = self._make_usage_info(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        reasoning_tokens=self._count_reasoning_tokens_for_usage(
+                            raw_output_token_ids[i], reasoning_parser
+                        ),
+                    )
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+            final_usage = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=num_prompt_tokens + completion_tokens,
+            )
+            if self.enable_prompt_tokens_details and num_cached_tokens:
+                final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+                    cached_tokens=num_cached_tokens
+                )
+""",
+        """\
+            reasoning_tokens = None
+            if reasoning_parser is not None:
+                reasoning_tokens = sum(
+                    self._count_reasoning_tokens_for_usage(
+                        token_ids, reasoning_parser
+                    )
+                    or 0
+                    for token_ids in raw_output_token_ids
+                )
+            final_usage = self._make_usage_info(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=completion_tokens,
+                num_cached_tokens=num_cached_tokens,
+                reasoning_tokens=reasoning_tokens,
+            )
+""",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+        request_metadata.final_usage_info = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_completion_tokens,
+            total_tokens=num_prompt_tokens + num_completion_tokens,
+        )
+""",
+        """\
+        reasoning_tokens = None
+        if reasoning_parser is not None:
+            reasoning_tokens = sum(
+                self._count_reasoning_tokens_for_usage(
+                    token_ids, reasoning_parser
+                )
+                or 0
+                for token_ids in raw_output_token_ids
+            )
+        request_metadata.final_usage_info = self._make_usage_info(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_completion_tokens,
+            reasoning_tokens=reasoning_tokens,
+        )
+""",
+    )
+
+    _install_method("chat_completion_stream_generator", method_source)
+
+
+def _patch_chat_completion_full_generator() -> None:
+    method_source = _extract_class_method_source(
+        chat_serving.__file__,
+        "OpenAIServingChat",
+        "chat_completion_full_generator",
+    )
+
+    method_source = _replace_block(
+        method_source,
+        """\
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        usage.prompt_tokens_details = PromptTokenUsageInfo(
+            cached_tokens=final_res.num_cached_tokens
+        )
+""",
+        """\
+    reasoning_tokens = None
+    if reasoning_parser is not None:
+        reasoning_tokens = sum(
+            self._count_reasoning_tokens_for_usage(
+                as_list(output.token_ids), reasoning_parser
+            )
+            or 0
+            for output in final_res.outputs
+        )
+    usage = self._make_usage_info(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        num_cached_tokens=final_res.num_cached_tokens,
+        reasoning_tokens=reasoning_tokens,
+    )
+""",
+    )
+
+    _install_method("chat_completion_full_generator", method_source)
+
+
+_patch_chat_completion_stream_generator()
+_patch_chat_completion_full_generator()
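+
+# A quick, illustrative way to check that the patch took effect (not executed
+# here): the extended schema should be visible on the runtime protocol module.
+#   from vllm.entrypoints.openai.engine import protocol as engine_protocol
+#   assert "completion_tokens_details" in engine_protocol.UsageInfo.model_fields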