From 274b708e0c025dfc3b6e3ac6891f7efb8aa337ca Mon Sep 17 00:00:00 2001 From: Yizhou <136800916+yiz-liu@users.noreply.github.com> Date: Tue, 21 Oct 2025 00:00:42 +0800 Subject: [PATCH] [Fix] Refactor dummy attention metadata creation (#3497) ### What this PR does / why we need it? The `force_attention` parameter is designed for FlashInfer kernel warmup; we don't actually need it on Ascend device (at least for now). And it tends to make things more complicated. So we replace the `force_attention` parameter with `aclgraph_runtime_mode` in the attention metadata creation logic. This change makes the control flow more explicit by directly using the graph runtime mode to determine how to build attention metadata, rather than relying on an intermediate boolean flag. This simplification removes redundant logic and clarifies the conditions for building attention metadata for full decode graph mode. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? DP + `FULL_DECODE_ONLY` + online serving. 
- vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: Yizhou Liu --- tests/ut/worker/test_worker_v1.py | 4 +-- vllm_ascend/torchair/torchair_model_runner.py | 19 ++++++++---- vllm_ascend/worker/model_runner_v1.py | 30 ++++++++++++------- vllm_ascend/worker/worker_v1.py | 6 ++-- 4 files changed, 35 insertions(+), 24 deletions(-) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 8d55a94..2313e71 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -456,9 +456,7 @@ class TestNPUWorker(TestBase): # Verify call mock_model_runner._dummy_run.assert_called_once_with( - num_tokens=mock_decode_token_per_req, - uniform_decode=True, - force_attention=False) + num_tokens=mock_decode_token_per_req, uniform_decode=True) @patch("vllm_ascend.worker.worker_v1.envs_vllm") @patch("vllm_ascend.worker.worker_v1.logger") diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2a5c513..fbdb42b 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -19,13 +19,13 @@ import math import types -from typing import Optional +from typing import Any, Optional import torch import torch.distributed as dist import torch.nn as nn import torch_npu -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context @@ -147,14 +147,21 @@ class NPUTorchairModelRunner(NPUModelRunner): return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo - def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens, - max_query_len, force_attention): + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: 
int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: # NOTE: If torchair graph mode and not with_prefill, # we can't skip_attn, it will cause graph recompile. if with_prefill or self.enable_shared_expert_dp: - attn_metadata = super()._build_attention_metadata( + attn_metadata = super()._build_dummy_attn_metadata( with_prefill, num_reqs, num_tokens, max_query_len, - force_attention) + aclgraph_runtime_mode, force_attention) else: common_attn_metadata = TorchairCommonAttentionMetadata( num_reqs=num_reqs, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c23ec03..784062c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2250,18 +2250,24 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def _build_attention_metadata(self, create_mixed_batch, num_reqs, - num_tokens, max_query_len, force_attention): + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: attn_metadata: Optional[dict[str, Any]] = None - if force_attention: + if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL: + assert with_prefill is False, \ + "Full decode graph only supports uniform batch now." 
+ attn_metadata = {} - if create_mixed_batch: - raise NotImplementedError( - "force_attention=True is not supported for mixed batches.") - else: - seq_lens = self.model_config.max_model_len + seq_lens = self.model_config.max_model_len self.seq_lens_np[:num_reqs] = seq_lens self.seq_lens_np[num_reqs:] = 0 @@ -2321,7 +2327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): forward_context = get_forward_context() assert forward_context is not None if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ - not forward_context.capturing: + not forward_context.capturing and forward_context.attn_metadata is not None: if self.vllm_config.model_config.use_mla: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, @@ -2409,12 +2415,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True - # TODO(cmq): check if with_prefill is reasonable - attn_metadata = self._build_attention_metadata( + # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup + # and not supported in ASCEND now. We could remove it in the future. 
+ attn_metadata = self._build_dummy_attn_metadata( False, num_reqs=num_reqs, num_tokens=num_tokens, max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, force_attention=force_attention, ) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index d26c077..d8be9c2 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -26,7 +26,7 @@ import torch_npu import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized @@ -360,11 +360,9 @@ class NPUWorker(WorkerBase): return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY self.model_runner._dummy_run( num_tokens=self.model_runner.decode_token_per_req, - uniform_decode=True, - force_attention=force_attention) + uniform_decode=True) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment."""