[Fix] Refactor dummy attention metadata creation (#3497)

### What this PR does / why we need it?
The `force_attention` parameter is designed for flash infer kernel
warmup, we don't actually need it on Ascend device (at least for
now). And it tends to make things more complicated. So we replace the
`force_attention` parameter with `aclgraph_runtime_mode` in the
attention metadata creation logic.

This change makes the control flow more explicit by directly using the
graph runtime mode to determine how to build attention metadata, rather
than relying on an intermediate boolean flag. This simplification
removes redundant logic and clarifies the conditions for building
attention metadata for full decode graph mode.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
DP + `FULL_DECODE_ONLY` + online serving.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-10-21 00:00:42 +08:00
committed by GitHub
parent 6b6857929d
commit 274b708e0c
4 changed files with 35 additions and 24 deletions

View File

@@ -456,9 +456,7 @@ class TestNPUWorker(TestBase):
# Verify call # Verify call
mock_model_runner._dummy_run.assert_called_once_with( mock_model_runner._dummy_run.assert_called_once_with(
num_tokens=mock_decode_token_per_req, num_tokens=mock_decode_token_per_req, uniform_decode=True)
uniform_decode=True,
force_attention=False)
@patch("vllm_ascend.worker.worker_v1.envs_vllm") @patch("vllm_ascend.worker.worker_v1.envs_vllm")
@patch("vllm_ascend.worker.worker_v1.logger") @patch("vllm_ascend.worker.worker_v1.logger")

View File

@@ -19,13 +19,13 @@
import math import math
import types import types
from typing import Optional from typing import Any, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch_npu import torch_npu
from vllm.config import VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -147,14 +147,21 @@ class NPUTorchairModelRunner(NPUModelRunner):
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens, def _build_dummy_attn_metadata(
max_query_len, force_attention): self,
with_prefill: bool,
num_reqs: int,
num_tokens: int,
max_query_len: int,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
# NOTE: If torchair graph mode and not with_prefill, # NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile. # we can't skip_attn, it will cause graph recompile.
if with_prefill or self.enable_shared_expert_dp: if with_prefill or self.enable_shared_expert_dp:
attn_metadata = super()._build_attention_metadata( attn_metadata = super()._build_dummy_attn_metadata(
with_prefill, num_reqs, num_tokens, max_query_len, with_prefill, num_reqs, num_tokens, max_query_len,
force_attention) aclgraph_runtime_mode, force_attention)
else: else:
common_attn_metadata = TorchairCommonAttentionMetadata( common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=num_reqs, num_reqs=num_reqs,

View File

@@ -2250,18 +2250,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
scheduler_output.finished_req_ids) scheduler_output.finished_req_ids)
return None, None return None, None
def _build_attention_metadata(self, create_mixed_batch, num_reqs, def _build_dummy_attn_metadata(
num_tokens, max_query_len, force_attention): self,
with_prefill: bool,
num_reqs: int,
num_tokens: int,
max_query_len: int,
aclgraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False,
) -> Optional[dict[str, Any]]:
attn_metadata: Optional[dict[str, Any]] = None attn_metadata: Optional[dict[str, Any]] = None
if force_attention: if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL:
assert with_prefill is False, \
"Full decode graph only supports uniform batch now."
attn_metadata = {} attn_metadata = {}
if create_mixed_batch: seq_lens = self.model_config.max_model_len
raise NotImplementedError(
"force_attention=True is not supported for mixed batches.")
else:
seq_lens = self.model_config.max_model_len
self.seq_lens_np[:num_reqs] = seq_lens self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0 self.seq_lens_np[num_reqs:] = 0
@@ -2321,7 +2327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
forward_context = get_forward_context() forward_context = get_forward_context()
assert forward_context is not None assert forward_context is not None
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
not forward_context.capturing: not forward_context.capturing and forward_context.attn_metadata is not None:
if self.vllm_config.model_config.use_mla: if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True` # FIXME: Try using `auto_dispatch_capture=True`
update_mla_attn_params(self.update_stream, forward_context, update_mla_attn_params(self.update_stream, forward_context,
@@ -2409,12 +2415,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer and not self.is_kv_consumer: if self.is_kv_producer and not self.is_kv_consumer:
with_prefill = True with_prefill = True
# TODO(cmq): check if with_prefill is reasonable # TODO(Mengqing): Set create_mixed_batch to False since it's only used in FI warmup
attn_metadata = self._build_attention_metadata( # and not supported in ASCEND now. We could remove it in the future.
attn_metadata = self._build_dummy_attn_metadata(
False, False,
num_reqs=num_reqs, num_reqs=num_reqs,
num_tokens=num_tokens, num_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
aclgraph_runtime_mode=aclgraph_runtime_mode,
force_attention=force_attention, force_attention=force_attention,
) )

View File

@@ -26,7 +26,7 @@ import torch_npu
import vllm.envs as envs_vllm import vllm.envs as envs_vllm
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from torch_npu.profiler import dynamic_profile as dp from torch_npu.profiler import dynamic_profile as dp
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment) init_distributed_environment)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
@@ -360,11 +360,9 @@ class NPUWorker(WorkerBase):
return self.model_runner.pin_lora(lora_id) return self.model_runner.pin_lora(lora_id)
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
self.model_runner._dummy_run( self.model_runner._dummy_run(
num_tokens=self.model_runner.decode_token_per_req, num_tokens=self.model_runner.decode_token_per_req,
uniform_decode=True, uniform_decode=True)
force_attention=force_attention)
def _init_worker_distributed_environment(self) -> None: def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""