xc-llm-ascend/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# OpenAI chat usage accounting: backport MiniMax reasoning token accounting.
#
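# The patch has three parts: the usage models below add
# completion_tokens_details.reasoning_tokens to the OpenAI-compatible
# usage payload; the MiniMax M2 reasoning parsers gain a
# count_reasoning_tokens implementation; and
# OpenAIServingChat.chat_completion_full_generator is wrapped so the
# final response carries the extended usage.
#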
from __future__ import annotations
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from typing import Any
from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
from vllm.entrypoints.openai.chat_completion import serving as chat_serving
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine import protocol as engine_protocol
from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser

# Extended usage models: completion_tokens_details mirrors the OpenAI
# API's reasoning/audio/prediction token breakdown.
class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
    reasoning_tokens: int | None = None
    audio_tokens: int | None = None
    accepted_prediction_tokens: int | None = None
    rejected_prediction_tokens: int | None = None


class UsageInfo(engine_protocol.UsageInfo):
    completion_tokens_details: CompletionTokenUsageInfo | None = None


# Make the replacements indistinguishable from the originals, then install
# them everywhere the protocol and serving modules resolve these names.
CompletionTokenUsageInfo.__module__ = engine_protocol.__name__
UsageInfo.__module__ = engine_protocol.__name__
engine_protocol.CompletionTokenUsageInfo = CompletionTokenUsageInfo
engine_protocol.UsageInfo = UsageInfo
chat_protocol.UsageInfo = UsageInfo
chat_serving.CompletionTokenUsageInfo = CompletionTokenUsageInfo
chat_serving.UsageInfo = UsageInfo
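
# Illustrative shape of the patched usage payload in a chat completion
# response (the numbers are made up for the example):
#
#   "usage": {
#       "prompt_tokens": 24,
#       "completion_tokens": 100,
#       "total_tokens": 124,
#       "completion_tokens_details": {"reasoning_tokens": 37}
#   }
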
def _rebuild_model_field(model_cls, field_name: str, annotation) -> None:
    model_cls.__annotations__[field_name] = annotation
    model_cls.model_fields[field_name].annotation = annotation
    model_cls.model_rebuild(force=True)


_rebuild_model_field(chat_protocol.ChatCompletionResponse, "usage", UsageInfo)
_rebuild_model_field(
    chat_protocol.ChatCompletionStreamResponse, "usage", UsageInfo | None
)
_rebuild_model_field(
    engine_protocol.RequestResponseMetadata, "final_usage_info", UsageInfo | None
)
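
# _rebuild_model_field mutates both the class annotation and the pydantic
# field info before calling model_rebuild(force=True), which regenerates
# the cached validation/serialization schema; without the forced rebuild,
# pydantic would keep serializing "usage" against the original UsageInfo
# model and drop the new completion_tokens_details field.
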
def _count_minimax_reasoning_tokens(
    token_ids: Sequence[int],
    end_token_id: int | None,
) -> int:
    if end_token_id is None:
        return 0
    for idx, token_id in enumerate(token_ids):
        if token_id == end_token_id:
            return idx
    return len(token_ids)


def _patched_count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
    return _count_minimax_reasoning_tokens(token_ids, self.end_token_id)


minimax_parser.MiniMaxM2ReasoningParser.count_reasoning_tokens = (
    _patched_count_reasoning_tokens
)
minimax_parser.MiniMaxM2AppendThinkReasoningParser.count_reasoning_tokens = (
    _patched_count_reasoning_tokens
)
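
# A worked example of the counting rule above, with END standing for the
# parser's end-of-thinking token id: for token_ids = [11, 12, 13, END, 14]
# the count is 3, since only tokens before END are reasoning tokens, and
# for an output that never emits END the whole completion counts as
# reasoning.
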
def _count_reasoning_tokens_for_usage(
    token_ids: Sequence[int],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return reasoning_parser.count_reasoning_tokens(token_ids)


def _make_usage_info(
    self,
    *,
    prompt_tokens: int,
    completion_tokens: int,
    num_cached_tokens: int | None = None,
    reasoning_tokens: int | None = None,
) -> UsageInfo:
    usage = UsageInfo(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    if reasoning_tokens is not None:
        usage.completion_tokens_details = CompletionTokenUsageInfo(
            reasoning_tokens=max(0, min(reasoning_tokens, completion_tokens))
        )
    if self.enable_prompt_tokens_details and num_cached_tokens:
        usage.prompt_tokens_details = chat_serving.PromptTokenUsageInfo(
            cached_tokens=num_cached_tokens
        )
    return usage


OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(
    _count_reasoning_tokens_for_usage
)
OpenAIServingChat._make_usage_info = _make_usage_info
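
# The clamp in _make_usage_info keeps the detail consistent with the
# totals: reasoning_tokens=120 with completion_tokens=100 is reported as
# 100, and a negative count is reported as 0, so
# completion_tokens_details never contradicts completion_tokens.
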
@dataclass
class _UsageTrackingState:
    completion_tokens: list[int]
    raw_output_token_ids: list[list[int]]
    reasoning_parser: Any
    num_prompt_tokens: int = 0
    num_cached_tokens: int | None = None
    final_res: Any = None


def _create_usage_tracking_state(
    num_choices: int,
    reasoning_parser,
) -> _UsageTrackingState:
    return _UsageTrackingState(
        completion_tokens=[0] * num_choices,
        raw_output_token_ids=[[] for _ in range(num_choices)],
        reasoning_parser=reasoning_parser,
    )


def _update_usage_tracking_state(
    state: _UsageTrackingState,
    res,
) -> None:
    if res.prompt_token_ids is not None:
        num_prompt_tokens = len(res.prompt_token_ids)
        if res.encoder_prompt_token_ids is not None:
            num_prompt_tokens += len(res.encoder_prompt_token_ids)
        state.num_prompt_tokens = num_prompt_tokens
    if state.num_cached_tokens is None:
        state.num_cached_tokens = res.num_cached_tokens
    state.final_res = res
    for output in res.outputs:
        if 0 <= output.index < len(state.completion_tokens):
            token_ids = chat_serving.as_list(output.token_ids)
            state.completion_tokens[output.index] += len(token_ids)
            state.raw_output_token_ids[output.index].extend(token_ids)
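
# With request.n > 1 each result can carry outputs for several choices;
# output.index routes every delta into its own bucket, so
# completion_tokens and raw_output_token_ids stay per-choice while the
# prompt is counted once.
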
async def _tracked_result_generator(
    result_generator: AsyncIterator,
    state: _UsageTrackingState,
):
    async for res in result_generator:
        _update_usage_tracking_state(state, res)
        yield res


def _sum_reasoning_tokens_for_usage(
    raw_output_token_ids: list[list[int]],
    reasoning_parser,
) -> int | None:
    if reasoning_parser is None:
        return None
    return sum(
        _count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0
        for token_ids in raw_output_token_ids
    )


def _make_full_response_usage(
    self,
    state: _UsageTrackingState,
) -> UsageInfo | None:
    if state.final_res is None:
        return None
    return self._make_usage_info(
        prompt_tokens=state.num_prompt_tokens,
        completion_tokens=sum(state.completion_tokens),
        num_cached_tokens=state.num_cached_tokens,
        reasoning_tokens=_sum_reasoning_tokens_for_usage(
            state.raw_output_token_ids,
            state.reasoning_parser,
        ),
    )
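
# _tracked_result_generator is a transparent tap: it forwards every
# engine result unchanged while recording the token ids needed to
# rebuild usage after the response is assembled.
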
# Save the original generator once so re-importing this patch module stays
# idempotent and the wrapper never ends up wrapping itself.
if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
    OpenAIServingChat._ascend_original_chat_completion_full_generator = (
        OpenAIServingChat.chat_completion_full_generator
    )


async def _wrapped_chat_completion_full_generator(
    self,
    request: chat_protocol.ChatCompletionRequest,
    result_generator: AsyncIterator,
    request_id: str,
    model_name: str,
    conversation,
    tokenizer,
    request_metadata: engine_protocol.RequestResponseMetadata,
    reasoning_parser=None,
):
    num_choices = 1 if request.n is None else request.n
    state = _create_usage_tracking_state(num_choices, reasoning_parser)
    original_full_generator = self._ascend_original_chat_completion_full_generator
    response = await original_full_generator(
        request,
        _tracked_result_generator(result_generator, state),
        request_id,
        model_name,
        conversation,
        tokenizer,
        request_metadata,
        reasoning_parser,
    )
    if not isinstance(response, chat_protocol.ChatCompletionResponse):
        return response
    usage = _make_full_response_usage(self, state)
    if usage is None:
        return response
    response.usage = usage
    request_metadata.final_usage_info = usage
    return response


_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
_wrapped_chat_completion_full_generator.__qualname__ = (
    f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
)
OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator
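
# A hypothetical smoke test of the patched accounting from the client
# side (the server URL, model name, and prompt are assumptions, not part
# of this patch):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="MiniMax-M2",
#       messages=[{"role": "user", "content": "How many primes are below 20?"}],
#   )
#   print(resp.usage.completion_tokens_details.reasoning_tokens)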