diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 8d55a94..2313e71 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -456,9 +456,7 @@ class TestNPUWorker(TestBase): # Verify call mock_model_runner._dummy_run.assert_called_once_with( - num_tokens=mock_decode_token_per_req, - uniform_decode=True, - force_attention=False) + num_tokens=mock_decode_token_per_req, uniform_decode=True) @patch("vllm_ascend.worker.worker_v1.envs_vllm") @patch("vllm_ascend.worker.worker_v1.logger") diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2a5c513..fbdb42b 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -19,13 +19,13 @@ import math import types -from typing import Optional +from typing import Any, Optional import torch import torch.distributed as dist import torch.nn as nn import torch_npu -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context @@ -147,14 +147,21 @@ class NPUTorchairModelRunner(NPUModelRunner): return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo - def _build_attention_metadata(self, with_prefill, num_reqs, num_tokens, - max_query_len, force_attention): + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: # NOTE: If torchair graph mode and not with_prefill, # we can't skip_attn, it will cause graph recompile. 
if with_prefill or self.enable_shared_expert_dp: - attn_metadata = super()._build_attention_metadata( + attn_metadata = super()._build_dummy_attn_metadata( with_prefill, num_reqs, num_tokens, max_query_len, - force_attention) + aclgraph_runtime_mode, force_attention) else: common_attn_metadata = TorchairCommonAttentionMetadata( num_reqs=num_reqs, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c23ec03..784062c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2250,18 +2250,24 @@ class NPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None - def _build_attention_metadata(self, create_mixed_batch, num_reqs, - num_tokens, max_query_len, force_attention): + def _build_dummy_attn_metadata( + self, + with_prefill: bool, + num_reqs: int, + num_tokens: int, + max_query_len: int, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, + force_attention: bool = False, + ) -> Optional[dict[str, Any]]: attn_metadata: Optional[dict[str, Any]] = None - if force_attention: + if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL: + assert with_prefill is False, \ + "Full decode graph only supports uniform batch now." 
+ attn_metadata = {} - if create_mixed_batch: -            raise NotImplementedError( -                "force_attention=True is not supported for mixed batches.") -        else: -            seq_lens = self.model_config.max_model_len + seq_lens = self.model_config.max_model_len          self.seq_lens_np[:num_reqs] = seq_lens         self.seq_lens_np[num_reqs:] = 0 @@ -2321,7 +2327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):              forward_context = get_forward_context()             assert forward_context is not None             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ -                not forward_context.capturing: +                not forward_context.capturing and forward_context.attn_metadata is not None:                 if self.vllm_config.model_config.use_mla:                     # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context,  @@ -2409,12 +2415,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):          if self.is_kv_producer and not self.is_kv_consumer:             with_prefill = True  -        # TODO(cmq): check if with_prefill is reasonable -        attn_metadata = self._build_attention_metadata( +        # TODO(Mengqing): Set with_prefill to False since mixed batches are only used in FI warmup +        # and not supported in ASCEND now. We could remove it in the future.
+ attn_metadata = self._build_dummy_attn_metadata( False, num_reqs=num_reqs, num_tokens=num_tokens, max_query_len=max_query_len, + aclgraph_runtime_mode=aclgraph_runtime_mode, force_attention=force_attention, ) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index d26c077..d8be9c2 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -26,7 +26,7 @@ import torch_npu import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import CUDAGraphMode, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized @@ -360,11 +360,9 @@ class NPUWorker(WorkerBase): return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY self.model_runner._dummy_run( num_tokens=self.model_runner.decode_token_per_req, - uniform_decode=True, - force_attention=force_attention) + uniform_decode=True) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment."""