diff --git a/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py b/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py
index 3456687d..6950ff34 100644
--- a/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py
+++ b/tests/ut/patch/platform/test_patch_minimax_usage_accounting.py
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import pytest
+from types import SimpleNamespace
 
+import pytest
 from vllm.reasoning.minimax_m2_reasoning_parser import (
     MiniMaxM2AppendThinkReasoningParser,
     MiniMaxM2ReasoningParser,
 )
 
+from vllm_ascend.patch.platform import (
+    patch_minimax_usage_accounting as minimax_usage_patch,
+)
+
 
 class FakeTokenizer:
     def get_vocab(self):
@@ -35,13 +40,13 @@ class FakeTokenizer:
             MiniMaxM2ReasoningParser,
             [10, 11, 20],
             3,
-            id="minimax-all-tokens-are-reasoning-before-end-token",
+            id="minimax-no-end-token-means-all-output-is-reasoning",
         ),
         pytest.param(
             MiniMaxM2AppendThinkReasoningParser,
             [10, 11, 20],
             3,
-            id="append-think-all-tokens-are-reasoning-before-end-token",
+            id="append-think-no-end-token-means-all-output-is-reasoning",
         ),
         pytest.param(
             MiniMaxM2ReasoningParser,
@@ -65,3 +70,73 @@ def test_count_reasoning_tokens(
     parser = parser_cls(FakeTokenizer())
 
     assert parser.count_reasoning_tokens(token_ids) == expected_reasoning_tokens
+
+
+def test_update_usage_tracking_state_tracks_prompt_and_completion_tokens():
+    state = minimax_usage_patch._create_usage_tracking_state(
+        num_choices=2,
+        reasoning_parser=None,
+    )
+
+    res = SimpleNamespace(
+        prompt_token_ids=[1, 2],
+        encoder_prompt_token_ids=[3],
+        num_cached_tokens=4,
+        outputs=[
+            SimpleNamespace(index=0, token_ids=(10, 11)),
+            SimpleNamespace(index=1, token_ids=[20]),
+        ],
+    )
+
+    minimax_usage_patch._update_usage_tracking_state(state, res)
+
+    assert state.num_prompt_tokens == 3
+    assert state.num_cached_tokens == 4
+    assert state.completion_tokens == [2, 1]
+    assert state.raw_output_token_ids == [[10, 11], [20]]
+
+
+def test_make_usage_info_injects_reasoning_token_details():
+    fake_serving = SimpleNamespace(enable_prompt_tokens_details=True)
+    usage = minimax_usage_patch._make_usage_info(
+        fake_serving,
+        prompt_tokens=3,
+        completion_tokens=4,
+        num_cached_tokens=1,
+        reasoning_tokens=2,
+    )
+
+    payload = usage.model_dump(exclude_none=True)
+
+    assert payload["completion_tokens_details"]["reasoning_tokens"] == 2
+    assert payload["prompt_tokens_details"]["cached_tokens"] == 1
+
+
+def test_make_full_response_usage_sums_reasoning_tokens():
+    class FakeServing:
+        enable_prompt_tokens_details = False
+
+        def _make_usage_info(self, **kwargs):
+            return minimax_usage_patch._make_usage_info(self, **kwargs)
+
+    class FakeReasoningParser:
+        def count_reasoning_tokens(self, token_ids):
+            return 2 if 2 in token_ids else 0
+
+    state = minimax_usage_patch._create_usage_tracking_state(
+        num_choices=2,
+        reasoning_parser=FakeReasoningParser(),
+    )
+    state.num_prompt_tokens = 3
+    state.num_cached_tokens = 1
+    state.final_res = SimpleNamespace(num_cached_tokens=1)
+    state.completion_tokens = [4, 2]
+    state.raw_output_token_ids = [[10, 11, 2, 20], [30, 31]]
+
+    usage = minimax_usage_patch._make_full_response_usage(FakeServing(), state)
+
+    assert usage.prompt_tokens == 3
+    assert usage.completion_tokens == 6
+    assert usage.total_tokens == 9
+    assert usage.completion_tokens_details.reasoning_tokens == 2
+    assert usage.prompt_tokens_details is None
diff --git a/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py b/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
index 6c416cb3..ace83e4c 100644
--- a/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
+++ b/vllm_ascend/patch/platform/patch_minimax_usage_accounting.py
@@ -19,10 +19,8 @@ from __future__ import annotations
 
-import ast
-import textwrap
-from collections.abc import Sequence
-from pathlib import Path
+from collections.abc import AsyncIterator, Sequence
+from dataclasses import dataclass
 from typing import Any
 
 from vllm.entrypoints.openai.chat_completion import protocol as chat_protocol
@@ -32,46 +30,6 @@ from vllm.entrypoints.openai.engine import protocol as engine_protocol
 from vllm.reasoning import minimax_m2_reasoning_parser as minimax_parser
 
 
-def _extract_class_method_source(
-    module_path: str,
-    class_name: str,
-    method_name: str,
-) -> str:
-    source = Path(module_path).read_text(encoding="utf-8")
-    tree = ast.parse(source)
-    for node in tree.body:
-        if isinstance(node, ast.ClassDef) and node.name == class_name:
-            for item in node.body:
-                if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == method_name:
-                    method_source = ast.get_source_segment(source, item)
-                    if method_source is None:
-                        break
-                    return textwrap.dedent(method_source)
-
-    raise RuntimeError(f"Unable to extract {class_name}.{method_name} from {module_path}.")
-
-
-def _install_method(method_name: str, method_source: str) -> None:
-    namespace: dict[str, Any] = {}
-    exec(method_source, chat_serving.__dict__, namespace)
-    method = namespace[method_name]
-    method.__module__ = OpenAIServingChat.__module__
-    method.__qualname__ = f"{OpenAIServingChat.__qualname__}.{method_name}"
-    setattr(OpenAIServingChat, method_name, method)
-
-
-def _replace_block(
-    source: str,
-    old: str,
-    new: str,
-    *,
-    count: int = 1,
-) -> str:
-    if source.count(old) < count:
-        raise RuntimeError("Failed to locate expected block while patching OpenAIServingChat usage accounting.")
-    return source.replace(old, new, count)
-
-
 class CompletionTokenUsageInfo(engine_protocol.OpenAIBaseModel):
     reasoning_tokens: int | None = None
     audio_tokens: int | None = None
@@ -160,208 +118,132 @@ OpenAIServingChat._count_reasoning_tokens_for_usage = staticmethod(_count_reason
 OpenAIServingChat._make_usage_info = _make_usage_info
 
 
-def _patch_chat_completion_stream_generator() -> None:
-    method_source = _extract_class_method_source(
-        chat_serving.__file__,
-        "OpenAIServingChat",
-        "chat_completion_stream_generator",
+@dataclass
+class _UsageTrackingState:
+    completion_tokens: list[int]
+    raw_output_token_ids: list[list[int]]
+    reasoning_parser: Any
+    num_prompt_tokens: int = 0
+    num_cached_tokens: int | None = None
+    final_res: Any = None
+
+
+def _create_usage_tracking_state(
+    num_choices: int,
+    reasoning_parser,
+) -> _UsageTrackingState:
+    return _UsageTrackingState(
+        completion_tokens=[0] * num_choices,
+        raw_output_token_ids=[[] for _ in range(num_choices)],
+        reasoning_parser=reasoning_parser,
     )
 
-    method_source = _replace_block(
-        method_source,
-        """\
-    previous_num_tokens = [0] * num_choices
-    finish_reason_sent = [False] * num_choices
-""",
-        """\
-    previous_num_tokens = [0] * num_choices
-    raw_output_token_ids = [[] for _ in range(num_choices)]
-    finish_reason_sent = [False] * num_choices
-""",
+
+def _update_usage_tracking_state(
+    state: _UsageTrackingState,
+    res,
+) -> None:
+    if res.prompt_token_ids is not None:
+        num_prompt_tokens = len(res.prompt_token_ids)
+        if res.encoder_prompt_token_ids is not None:
+            num_prompt_tokens += len(res.encoder_prompt_token_ids)
+        state.num_prompt_tokens = num_prompt_tokens
+
+    if state.num_cached_tokens is None:
+        state.num_cached_tokens = res.num_cached_tokens
+
+    state.final_res = res
+
+    for output in res.outputs:
+        if 0 <= output.index < len(state.completion_tokens):
+            token_ids = chat_serving.as_list(output.token_ids)
+            state.completion_tokens[output.index] += len(token_ids)
+            state.raw_output_token_ids[output.index].extend(token_ids)
+
+
+async def _tracked_result_generator(
+    result_generator: AsyncIterator,
+    state: _UsageTrackingState,
+):
+    async for res in result_generator:
+        _update_usage_tracking_state(state, res)
+        yield res
+
+
+def _sum_reasoning_tokens_for_usage(
+    raw_output_token_ids: list[list[int]],
+    reasoning_parser,
+) -> int | None:
+    if reasoning_parser is None:
+        return None
+    return sum(
+        _count_reasoning_tokens_for_usage(token_ids, reasoning_parser) or 0 for token_ids in raw_output_token_ids
    )
 
-    method_source = _replace_block(
-        method_source,
-        """\
-                if include_continuous_usage:
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=0,
-                        total_tokens=num_prompt_tokens,
-                    )
-""",
-        """\
-                if include_continuous_usage:
-                    chunk.usage = self._make_usage_info(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=0,
-                        reasoning_tokens=self._count_reasoning_tokens_for_usage(
-                            raw_output_token_ids[i], reasoning_parser
-                        ),
-                    )
-""",
+
+def _make_full_response_usage(
+    self,
+    state: _UsageTrackingState,
+) -> UsageInfo | None:
+    if state.final_res is None:
+        return None
+
+    return self._make_usage_info(
+        prompt_tokens=state.num_prompt_tokens,
+        completion_tokens=sum(state.completion_tokens),
+        num_cached_tokens=state.num_cached_tokens,
+        reasoning_tokens=_sum_reasoning_tokens_for_usage(
+            state.raw_output_token_ids,
+            state.reasoning_parser,
+        ),
    )
 
-    method_source = _replace_block(
-        method_source,
-        """\
-                if include_continuous_usage:
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=0,
-                        total_tokens=num_prompt_tokens,
-                    )
-""",
-        """\
-                if include_continuous_usage:
-                    chunk.usage = self._make_usage_info(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=0,
-                        reasoning_tokens=self._count_reasoning_tokens_for_usage(
-                            raw_output_token_ids[i], reasoning_parser
-                        ),
-                    )
-""",
+
+if not hasattr(OpenAIServingChat, "_ascend_original_chat_completion_full_generator"):
+    OpenAIServingChat._ascend_original_chat_completion_full_generator = OpenAIServingChat.chat_completion_full_generator
+
+
+async def _wrapped_chat_completion_full_generator(
+    self,
+    request: chat_protocol.ChatCompletionRequest,
+    result_generator: AsyncIterator,
+    request_id: str,
+    model_name: str,
+    conversation,
+    tokenizer,
+    request_metadata: engine_protocol.RequestResponseMetadata,
+    reasoning_parser=None,
+):
+    num_choices = 1 if request.n is None else request.n
+    state = _create_usage_tracking_state(num_choices, reasoning_parser)
+
+    original_full_generator = self._ascend_original_chat_completion_full_generator
+    response = await original_full_generator(
+        request,
+        _tracked_result_generator(result_generator, state),
+        request_id,
+        model_name,
+        conversation,
+        tokenizer,
+        request_metadata,
+        reasoning_parser,
    )
 
-    method_source = _replace_block(
-        method_source,
-        """\
-            previous_num_tokens[i] += len(output.token_ids)
-""",
-        """\
-            previous_num_tokens[i] += len(output.token_ids)
-            raw_output_token_ids[i].extend(as_list(output.token_ids))
-""",
-    )
+    if not isinstance(response, chat_protocol.ChatCompletionResponse):
+        return response
 
-    method_source = _replace_block(
-        method_source,
-        """\
-                if include_continuous_usage:
-                    completion_tokens = previous_num_tokens[i]
-                    chunk.usage = UsageInfo(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=num_prompt_tokens + completion_tokens,
-                    )
-""",
-        """\
-                if include_continuous_usage:
-                    completion_tokens = previous_num_tokens[i]
-                    chunk.usage = self._make_usage_info(
-                        prompt_tokens=num_prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        reasoning_tokens=self._count_reasoning_tokens_for_usage(
-                            raw_output_token_ids[i], reasoning_parser
-                        ),
-                    )
-""",
-    )
+    usage = _make_full_response_usage(self, state)
+    if usage is None:
+        return response
 
-    method_source = _replace_block(
-        method_source,
-        """\
-        final_usage = UsageInfo(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=num_prompt_tokens + completion_tokens,
-        )
-        if self.enable_prompt_tokens_details and num_cached_tokens:
-            final_usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=num_cached_tokens
-            )
-""",
-        """\
-        reasoning_tokens = None
-        if reasoning_parser is not None:
-            reasoning_tokens = sum(
-                self._count_reasoning_tokens_for_usage(
-                    token_ids, reasoning_parser
-                )
-                or 0
-                for token_ids in raw_output_token_ids
-            )
-        final_usage = self._make_usage_info(
-            prompt_tokens=num_prompt_tokens,
-            completion_tokens=completion_tokens,
-            num_cached_tokens=num_cached_tokens,
-            reasoning_tokens=reasoning_tokens,
-        )
-""",
-    )
-
-    method_source = _replace_block(
-        method_source,
-        """\
-    request_metadata.final_usage_info = UsageInfo(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=num_completion_tokens,
-        total_tokens=num_prompt_tokens + num_completion_tokens,
-    )
-""",
-        """\
-    reasoning_tokens = None
-    if reasoning_parser is not None:
-        reasoning_tokens = sum(
-            self._count_reasoning_tokens_for_usage(
-                token_ids, reasoning_parser
-            )
-            or 0
-            for token_ids in raw_output_token_ids
-        )
-    request_metadata.final_usage_info = self._make_usage_info(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=num_completion_tokens,
-        reasoning_tokens=reasoning_tokens,
-    )
-""",
-    )
-
-    _install_method("chat_completion_stream_generator", method_source)
+    response.usage = usage
+    request_metadata.final_usage_info = usage
+    return response
 
 
-def _patch_chat_completion_full_generator() -> None:
-    method_source = _extract_class_method_source(
-        chat_serving.__file__,
-        "OpenAIServingChat",
-        "chat_completion_full_generator",
+_wrapped_chat_completion_full_generator.__module__ = OpenAIServingChat.__module__
+_wrapped_chat_completion_full_generator.__qualname__ = (
+    f"{OpenAIServingChat.__qualname__}.chat_completion_full_generator"
 )
 
-    method_source = _replace_block(
-        method_source,
-        """\
-    usage = UsageInfo(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=num_generated_tokens,
-        total_tokens=num_prompt_tokens + num_generated_tokens,
-    )
-    if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
-        usage.prompt_tokens_details = PromptTokenUsageInfo(
-            cached_tokens=final_res.num_cached_tokens
-        )
-""",
-        """\
-    reasoning_tokens = None
-    if reasoning_parser is not None:
-        reasoning_tokens = sum(
-            self._count_reasoning_tokens_for_usage(
-                as_list(output.token_ids), reasoning_parser
-            )
-            or 0
-            for output in final_res.outputs
-        )
-    usage = self._make_usage_info(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=num_generated_tokens,
-        num_cached_tokens=final_res.num_cached_tokens,
-        reasoning_tokens=reasoning_tokens,
-    )
-""",
-    )
-
-    _install_method("chat_completion_full_generator", method_source)
-
-
-_patch_chat_completion_stream_generator()
-_patch_chat_completion_full_generator()
+OpenAIServingChat.chat_completion_full_generator = _wrapped_chat_completion_full_generator