From 248ee7fa11337c9535412d3d13b146e7e8ecab75 Mon Sep 17 00:00:00 2001 From: anon189Ty Date: Fri, 17 Oct 2025 20:19:56 +0800 Subject: [PATCH] [Feat]Make full graph mode compatible with MTP (#3276) ### What this PR does / why we need it? Make the Full Graph mode able to run with MTP. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty --- .../spec_decode_v1/test_v1_mtp_correctness.py | 26 ++++++++++- .../test_v1_mtp_torchair_correctness.py | 25 ++++++++++- tests/ut/worker/test_worker_v1.py | 5 ++- vllm_ascend/attention/mla_v1.py | 21 +++++---- vllm_ascend/compilation/acl_graph.py | 19 ++++++-- vllm_ascend/worker/model_runner_v1.py | 44 +++++++++---------- vllm_ascend/worker/worker_v1.py | 7 +-- 7 files changed, 103 insertions(+), 44 deletions(-) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index 89d636a..b6d8b66 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -2,6 +2,7 @@ from __future__ import annotations import pytest from vllm import SamplingParams +from vllm.config import CompilationConfig, CUDAGraphMode from tests.e2e.conftest import VllmRunner @@ -20,6 +21,7 @@ def mtp_correctness( sampling_config: SamplingParams, model_name: str, num_speculative_tokens: int, + graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, ): example_prompts = [ "Hello, my name is", @@ -38,6 +40,10 @@ def mtp_correctness( enforce_eager=False) as ref_llm: ref_outputs = ref_llm.generate(example_prompts, sampling_config) + graph_mode_str = "PIECEWISE" + if graph_mode == CUDAGraphMode.FULL: + graph_mode_str = "FULL" + with VllmRunner( model_name, tensor_parallel_size=1, @@ -51,6 +57,8 @@ def mtp_correctness( }, enforce_eager=False, max_model_len=2000, + 
compilation_config=CompilationConfig( + cudagraph_mode=graph_mode_str), additional_config={"ascend_scheduler_config": { "enabled": False }}) as spec_llm: @@ -74,15 +82,29 @@ def mtp_correctness( del spec_llm -def test_mtp1_correctness( +def test_mtp1_correctness_piecewise_graph( sampling_config: SamplingParams, model_name: str, ): mtp_correctness(sampling_config, model_name, 1) -def test_mtp2_correctness( +def test_mtp2_correctness_piecewise_graph( sampling_config: SamplingParams, model_name: str, ): mtp_correctness(sampling_config, model_name, 2) + + +def test_mtp1_correctness_full_graph( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL) + + +def test_mtp2_correctness_full_graph( + sampling_config: SamplingParams, + model_name: str, +): + mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL) diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py index 64816ac..d509671 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_torchair_correctness.py @@ -2,6 +2,7 @@ from __future__ import annotations import pytest from vllm import SamplingParams +from vllm.config import CompilationConfig, CUDAGraphMode from tests.e2e.conftest import VllmRunner @@ -16,9 +17,10 @@ def model_name(): return "wemaster/deepseek_mtp_main_random_bf16" -def test_mtp_torchair_correctness( +def mtp_torchair_correctness( sampling_config: SamplingParams, model_name: str, + graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE, ): example_prompts = [ "Hello, my name is", @@ -44,6 +46,11 @@ def test_mtp_torchair_correctness( "multistream_overlap_shared_expert": "True" }) as ref_llm: ref_outputs = ref_llm.generate(example_prompts, sampling_config) + + graph_mode_str = "PIECEWISE" + if graph_mode == CUDAGraphMode.FULL: + 
graph_mode_str = "FULL" + with VllmRunner(model_name, tensor_parallel_size=1, max_num_seqs=256, @@ -56,6 +63,8 @@ def test_mtp_torchair_correctness( }, enforce_eager=False, max_model_len=2000, + compilation_config=CompilationConfig( + cudagraph_mode=graph_mode_str), additional_config={ "torchair_graph_config": { "enabled": True, @@ -81,3 +90,17 @@ def test_mtp_torchair_correctness( # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) + + +def test_mtp_torchair_correctness_piecewise( + sampling_config: SamplingParams, + model_name: str, +): + mtp_torchair_correctness(sampling_config, model_name) + + +def test_mtp_torchair_correctness_full( + sampling_config: SamplingParams, + model_name: str, +): + mtp_torchair_correctness(sampling_config, model_name, CUDAGraphMode.FULL) diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index 31e986d..8d55a94 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -448,6 +448,7 @@ class TestNPUWorker(TestBase): worker.compilation_config = MagicMock() worker.compilation_config.cudagraph_mode = MagicMock() mock_model_runner = MagicMock() + mock_decode_token_per_req = mock_model_runner.decode_token_per_req worker.model_runner = mock_model_runner # Test execute_dummy_batch @@ -455,7 +456,9 @@ class TestNPUWorker(TestBase): # Verify call mock_model_runner._dummy_run.assert_called_once_with( - num_tokens=1, uniform_decode=True, force_attention=False) + num_tokens=mock_decode_token_per_req, + uniform_decode=True, + force_attention=False) @patch("vllm_ascend.worker.worker_v1.envs_vllm") @patch("vllm_ascend.worker.worker_v1.logger") diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 46a9179..e40b5af 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -175,7 +175,7 @@ M = TypeVar("M", 
bound=AscendMLAMetadata) class AscendMLAMetadataBuilder: # Does this backend/builder support ACL Graphs for attention (default: no). aclgraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + AttentionCGSupport.UNIFORM_BATCH """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -209,7 +209,6 @@ class AscendMLAMetadataBuilder: got {self.decode_threshold}" self.reorder_batch_threshold = self.decode_threshold - if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( # Max sure there is enough for 8 full length request or at least @@ -427,10 +426,10 @@ class AscendMLAMetadataBuilder: sin=sin, cos=cos) else: - cos[:num_decodes, + cos[:num_decode_tokens, ...] = self.cos_cache[input_positions].unsqueeze( 1).unsqueeze(2) - sin[:num_decodes, + sin[:num_decode_tokens, ...] = self.sin_cache[input_positions].unsqueeze( 1).unsqueeze(2) @@ -442,8 +441,8 @@ class AscendMLAMetadataBuilder: max_seq_lens=max_seq_lens, attn_mask=common_attn_metadata.spec_attn_mask, actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin[:num_decodes, ...], - cos=cos[:num_decodes, ...]) + sin=sin[:num_decode_tokens, ...], + cos=cos[:num_decode_tokens, ...]) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -469,7 +468,10 @@ class AscendMLAMetadataBuilder: attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, model: Optional[nn.Module] = None, ): - if attn_state == AscendAttentionState.DecodeOnly: + if attn_state in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + }: attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -477,7 +479,7 @@ class AscendMLAMetadataBuilder: ) else: raise NotImplementedError( - "Currently we only support building dummy metadata for DecodeOnly state" + "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state" ) 
attn_metadata.attn_state = attn_state @@ -955,7 +957,8 @@ class AscendMLAImpl(MLAAttentionImpl): if attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, - AscendAttentionState.ChunkedPrefill + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.DecodeOnly, ] and self.speculative_config is not None: # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill input_layout = "TND" diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 60c7e39..2ba6b25 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -245,7 +245,8 @@ def update_attn_params(update_stream, forward_context, runtime_shape): event.record(update_stream) -def update_mla_attn_params(update_stream, forward_context, runtime_shape): +def update_mla_attn_params(update_stream, forward_context, runtime_shape, + speculative_config): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. 
@@ -260,9 +261,19 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape): seq_lens_list, actual_seq_lengths, workspace, attn_output, softmax_lse) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list - seq_lens_list = seq_lens_list + [0] * (runtime_shape - - len(seq_lens_list)) - + if speculative_config and speculative_config.method == "deepseek_mtp": + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + spec_multiple = speculative_config.num_speculative_tokens + 1 + seq_lens_list = seq_lens_list + [0] * ( + runtime_shape // spec_multiple - len(seq_lens_list)) + actual_seq_lengths = [ + spec_multiple * (i + 1) + for i in range(runtime_shape // spec_multiple) + ] + else: + seq_lens_list = seq_lens_list + [0] * (runtime_shape - + len(seq_lens_list)) with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 96bf679..50696b4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -345,6 +345,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.speculative_config.method, self.vllm_config, self.device, self) self.rejection_sampler = AscendRejectionSampler() + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) # Persistent batch. 
self.input_ids = torch.zeros(self.max_num_tokens, @@ -366,13 +369,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.vllm_config.model_config.use_mla and \ self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.cos = torch.ones(self.max_num_reqs, + self.cos = torch.ones(self.max_num_reqs * + self.decode_token_per_req, 1, 1, rope_dim, dtype=self.dtype, device=self.device) - self.sin = torch.zeros(self.max_num_reqs, + self.sin = torch.zeros(self.max_num_reqs * + self.decode_token_per_req, 1, 1, rope_dim, @@ -1554,7 +1559,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.vllm_config.model_config.use_mla: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, - maybe_padded_num_tokens) + maybe_padded_num_tokens, + self.speculative_config) else: update_attn_params(self.update_stream, forward_context, maybe_padded_num_tokens) @@ -2255,7 +2261,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): block_table_tensor = self.input_batch.block_table[ kv_cache_group_id].get_device_tensor() common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc=torch.tensor( + [0] + self.actual_seq_lengths_q[:num_reqs], + device=self.device, + dtype=torch.int32), query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu, @@ -2275,12 +2284,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): cos=self.cos, sin=self.sin, ) + attn_state = AscendAttentionState.DecodeOnly + if self.speculative_config and \ + self.speculative_config.method == "deepseek_mtp": + attn_state = AscendAttentionState.SpecDecoding for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() attn_metadata_i = builder.build_for_graph_capture( - common_attn_metadata, AscendAttentionState.DecodeOnly, - self.get_model()) + common_attn_metadata, 
attn_state, self.get_model()) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -2301,7 +2313,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.vllm_config.model_config.use_mla: # FIXME: Try using `auto_dispatch_capture=True` update_mla_attn_params(self.update_stream, forward_context, - positions.shape[0]) + positions.shape[0], + self.speculative_config) else: update_attn_params(self.update_stream, forward_context, positions.shape[0]) @@ -3388,23 +3401,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): CUDAGraphMode.FULL_DECODE_ONLY) logger.warning(msg) - # check that if spec-decode + decode full-graphs is supported - if (aclgraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_ag_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{aclgraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_ag_builder_name} (support: {min_ag_support})") - if splitting_ops_contain_attention: - msg += "; setting cudagraph_mode=PIECEWISE" - aclgraph_mode = self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE - else: - msg += "; setting cudagraph_mode=NONE" - aclgraph_mode = self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.NONE - logger.warning(msg) - # double check that we can support full graph if they are requested # even after automatic downgrades if aclgraph_mode.has_full_cudagraphs() \ diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index e3ced0a..d26c077 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -361,9 +361,10 @@ class NPUWorker(WorkerBase): def execute_dummy_batch(self) -> None: force_attention = self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY - self.model_runner._dummy_run(num_tokens=1, - uniform_decode=True, - force_attention=force_attention) + self.model_runner._dummy_run( + 
num_tokens=self.model_runner.decode_token_per_req, + uniform_decode=True, + force_attention=force_attention) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment."""