From 1b4ce63ec9cbdcdd2e7155466bab935fd8625fb2 Mon Sep 17 00:00:00 2001 From: XiaoxinWang <963372609@qq.com> Date: Wed, 12 Nov 2025 10:11:43 +0800 Subject: [PATCH] fix fullgraph in ds. (#4016) ### What this PR does / why we need it? DS don't have 'AscendAttentionMetadataBuilder' class so will fail in fullgraph. We resolved the issue by modifying the code to only check for 'GDNAttentionMetadataBuilder ', while all other attention cases follow the default branch. - vLLM version: v0.11.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac Signed-off-by: wangxiaoxin-sherie Co-authored-by: wangxiaoxin-sherie --- tests/e2e/multicard/test_external_launcher.py | 4 ---- tests/e2e/singlecard/test_aclgraph.py | 11 +++++++---- vllm_ascend/worker/model_runner_v1.py | 11 +++++------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/e2e/multicard/test_external_launcher.py b/tests/e2e/multicard/test_external_launcher.py index d5441691..05851db1 100644 --- a/tests/e2e/multicard/test_external_launcher.py +++ b/tests/e2e/multicard/test_external_launcher.py @@ -148,8 +148,6 @@ def test_external_launcher_and_sleepmode(): print(output) - assert "TP RANKS: [0]" in output - assert "TP RANKS: [1]" in output assert "Generated text:" in output assert "Sleep and wake up successfully!!" in output assert proc.returncode == 0 @@ -198,8 +196,6 @@ def test_external_launcher_and_sleepmode_level2(): print(output) - assert "TP RANKS: [0]" in output - assert "TP RANKS: [1]" in output assert "Generated text:" in output assert "Sleep and wake up successfully!!" in output assert proc.returncode == 0 diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index efa6cb39..86ce4795 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -100,8 +100,6 @@ def test_models_with_aclgraph( ) -@pytest.mark.skip("Skipping this test for now, " - "it fails intermittently and needs investigation.") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [5]) def test_models_with_aclgraph_full_decode_only( @@ -172,7 +170,10 @@ def test_models_with_aclgraph_full_decode_only( model, max_model_len=1024, enforce_eager=False, - compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"}, + compilation_config={ + "cudagraph_capture_sizes": [4, 8, 32, 64], + "cudagraph_mode": "FULL_DECODE_ONLY" + }, ) as runner: vllm_aclgraph_outputs = runner.model.generate( prompts, sampling_params) @@ -180,7 +181,9 @@ def test_models_with_aclgraph_full_decode_only( with VllmRunner( model, max_model_len=1024, - enforce_eager=True, + compilation_config={ + "cudagraph_capture_sizes": [4, 8, 32, 64], + }, ) as runner: vllm_eager_outputs = runner.model.generate(prompts, sampling_params) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 167345bf..7f8fe1e1 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -108,8 +108,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import (MoECommType, set_ascend_forward_context) from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import (AscendAttentionMetadataBuilder, - AscendAttentionState) +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, AscendPrefillContextParallelMetadata) # yapf: disable @@ -2887,12 +2886,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() - if isinstance(builder, AscendAttentionMetadataBuilder): - attn_metadata_full_attention = builder.build_for_graph_capture( - common_attn_metadata, attn_state, self.get_model()) - elif isinstance(builder, GDNAttentionMetadataBuilder): + if isinstance(builder, GDNAttentionMetadataBuilder): attn_metadata_gdn_attention = builder.build_for_cudagraph_capture( common_metadata) + else: + attn_metadata_full_attention = builder.build_for_graph_capture( + common_attn_metadata, attn_state, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: if "linear_attn" in layer_name: attn_metadata[