From f8c93d8d24d9fde488625d4dc7ec9e77124a661b Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Tue, 30 Sep 2025 11:14:51 +0800 Subject: [PATCH] [Aclgraph][DP] Fix dp dummy run not in aclgraph error (#3208) ### What this PR does / why we need it? When running DP in a non-equilibrium scenario, which means there are some DP groups executing `dummy_run`, we need to make sure they run in the same mode as the other DP groups, thus improving the performance in the DP scenario. ### How was this patch tested? Tested by adding log in `_dummy_run` - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: MengqingCao --- tests/ut/worker/test_worker_v1.py | 5 ++- vllm_ascend/worker/model_runner_v1.py | 48 ++++++++++++++++----------- vllm_ascend/worker/worker_v1.py | 7 ++-- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index eb05a7a..f4551de 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -444,6 +444,8 @@ class TestNPUWorker(TestBase): # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): worker = NPUWorker() + worker.compilation_config = MagicMock() + worker.compilation_config.cudagraph_mode = MagicMock() mock_model_runner = MagicMock() worker.model_runner = mock_model_runner @@ -451,7 +453,8 @@ class TestNPUWorker(TestBase): worker.execute_dummy_batch() # Verify call - mock_model_runner._dummy_run.assert_called_once_with(1) + mock_model_runner._dummy_run.assert_called_once_with( num_tokens=1, uniform_decode=True, force_attention=False) @patch("vllm_ascend.worker.worker_v1.envs_vllm") @patch("vllm_ascend.worker.worker_v1.logger") diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 4083ec4..08484fe 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2321,6 +2321,12 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds) + forward_context = get_forward_context() + assert forward_context is not None + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) + if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, _ = hidden_states else: @@ -2333,12 +2339,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_tokens: int, with_prefill: bool = False, is_torchair_compile: bool = False, - aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, + aclgraph_runtime_mode: Optional[CUDAGraphMode] = None, force_attention: bool = False, uniform_decode: bool = False, ) -> torch.Tensor: # only support eager mode and piecewise graph now - assert aclgraph_runtime_mode in { + assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in { CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } @@ -2371,8 +2377,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): max_num_reqs = self.scheduler_config.max_num_seqs if uniform_decode: num_reqs = cdiv(num_tokens, max_query_len) - assert num_reqs <= max_num_reqs, \ - "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] = num_tokens % max_query_len @@ -2395,12 +2399,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.is_kv_producer and not self.is_kv_consumer: with_prefill = True + # TODO(cmq): check if with_prefill is reasonable attn_metadata = self._build_attention_metadata( - with_prefill, - num_reqs, - num_tokens, - max_query_len, - force_attention, + False, + num_reqs=num_reqs, + num_tokens=num_tokens, + max_query_len=max_query_len, + force_attention=force_attention, ) if not self.in_profile_run and self.dynamic_eplb: @@ -2433,18 +2438,21 @@ class 
NPUModelRunner(LoRAModelRunnerMixin): k: v[:num_tokens] for k, v in self.intermediate_tensors.items() }) - if aclgraph_runtime_mode == CUDAGraphMode.NONE: - batch_descriptor = None - else: - # filter out the valid batch descriptor - _cg_mode, batch_descriptor = \ - self.aclgraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) - # sanity check - assert aclgraph_runtime_mode == _cg_mode, ( + + # filter out the valid batch descriptor + _ag_mode, batch_descriptor = \ + self.aclgraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) + if aclgraph_runtime_mode is not None: + # we allow forcing NONE when the dispatcher disagrees to support + # warm ups for aclgraph capture + assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \ + aclgraph_runtime_mode == _ag_mode, ( f"Aclgraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {aclgraph_runtime_mode}.") + f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.") + else: + aclgraph_runtime_mode = _ag_mode need_dummy_logits = (not self.in_profile_run and lmhead_tp_enable()) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index dc82ece..f1acda3 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -26,7 +26,7 @@ import torch_npu import vllm.envs as envs_vllm from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions from torch_npu.profiler import dynamic_profile as dp -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized @@ -356,7 +356,10 @@ class NPUWorker(WorkerBase): return self.model_runner.pin_lora(lora_id) def execute_dummy_batch(self) -> None: - self.model_runner._dummy_run(1) + force_attention = self.compilation_config.cudagraph_mode 
== CUDAGraphMode.FULL_DECODE_ONLY + self.model_runner._dummy_run(num_tokens=1, + uniform_decode=True, + force_attention=force_attention) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment."""