diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py index bbb6e01..ed5aa55 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py @@ -39,7 +39,7 @@ def mtp_correctness( tensor_parallel_size=1, gpu_memory_utilization=0.7, max_model_len=256, - enforce_eager=True) as ref_llm: + enforce_eager=False) as ref_llm: ref_outputs = ref_llm.generate(example_prompts, sampling_config) with VllmRunner( @@ -53,7 +53,7 @@ def mtp_correctness( "method": "deepseek_mtp", "num_speculative_tokens": num_speculative_tokens, }, - enforce_eager=True, + enforce_eager=False, max_model_len=2000, additional_config={"ascend_scheduler_config": { "enabled": False diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index a1df85b..5fdc202 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -186,6 +186,34 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + + ascend_config = MagicMock() + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", + return_value=ascend_config): + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) + + self.assertEqual(builder.block_size, + mock_vllm_config.cache_config.block_size) + self.assertEqual( + builder.chunked_prefill_enabled, + mock_vllm_config.scheduler_config.chunked_prefill_enabled) + + def test_ascend_mla_metadata_builder_spec_decode(self): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + mock_spec_config = MagicMock() + mock_spec_config.num_speculative_tokens = 3 + mock_vllm_config.speculative_config = mock_spec_config + ascend_config = MagicMock() with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): @@ -208,6 +236,8 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + with patch("vllm_ascend.attention.mla_v1.get_ascend_config", return_value=ascend_config): builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index 0990752..0e0150c 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -190,6 +190,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + ascend_config = MagicMock() ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True @@ -217,6 +219,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): ascend_config.torchair_graph_config = MagicMock() ascend_config.torchair_graph_config.enabled = True + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder(None, None, mock_vllm_config, mock_device) @@ -252,6 +256,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config", return_value=ascend_config): builder = AscendMLATorchairMetadataBuilder(None, None, @@ -288,6 +294,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder(None, None, mock_vllm_config, mock_device) @@ -309,6 +317,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder(None, None, mock_vllm_config, mock_device) @@ -331,6 +341,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.chunked_prefill_enabled = False mock_device = 'cpu' + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder(None, None, mock_vllm_config, mock_device) @@ -357,6 +369,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): mock_vllm_config.model_config.dtype = torch.float16 mock_device = 'cpu' + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder( None, None, @@ -424,6 +438,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase): model = MagicMock(spec=nn.Module) model.model = MagicMock(spec=nn.Module) + mock_vllm_config.speculative_config = None + builder = AscendMLATorchairMetadataBuilder( None, None, diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index aa6c597..d287bad 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -187,7 +187,14 @@ class AscendMLAMetadataBuilder: self.block_size - 1) // self.block_size self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + self.speculative_config = vllm_config.speculative_config self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -275,7 +282,6 @@ class AscendMLAMetadataBuilder: num_actual_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - # TODO(xyx): remove the if condition after mla supports torch mode speculative decoding num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) assert num_decodes + num_prefills == num_reqs diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 4749bac..0a96b25 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -8,7 +8,7 @@ from torchair import patch_for_hcom from vllm.attention.layer import Attention from vllm.config import (VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import get_forward_context +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, set_default_torch_dtype) @@ -363,8 +363,14 @@ class MtpProposer(Proposer): not self.runner.with_prefill if is_running_torchair: + # Torchair graph mode, padding is same as the main model num_input_tokens = self.runner.graph_pad_size + elif (self.runner.use_aclgraph + and num_tokens <= self.runner.aclgraph_batch_sizes[-1]): + # Acl graph mode, add padding to the batch size + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: + # Eager mode, no padding needed num_input_tokens = num_tokens seq_lens = target_positions[last_token_indices] + 1 @@ -410,7 +416,7 @@ class MtpProposer(Proposer): # TODO: adapt enable_dbo later (num_input_tokens, num_tokens_across_dp, with_prefill, _) = self.runner._sync_metadata_across_dp( - num_tokens, self.runner.with_prefill, False) + num_input_tokens, self.runner.with_prefill, False) else: # torchair mode can reuse self.runner.num_tokens_across_dp num_tokens_across_dp = self.runner.num_tokens_across_dp @@ -418,6 +424,10 @@ class MtpProposer(Proposer): moe_comm_method = self.runner._select_moe_comm_method( num_input_tokens, with_prefill) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) + aclgraph_runtime_mode, batch_descriptor = \ + self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) for step in range(self.num_speculative_tokens): with set_ascend_forward_context( @@ -428,6 +438,7 @@ class MtpProposer(Proposer): num_tokens_across_dp=num_tokens_across_dp, reserved_mc2_mask=self.runner.reserved_mc2_mask, moe_comm_method=moe_comm_method, + aclgraph_runtime_mode=aclgraph_runtime_mode, in_profile_run=self.runner.in_profile_run, num_actual_tokens=num_tokens): with ProfileExecuteDuration().capture_async('mtp_forward'): diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index e43d912..c9c2d61 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -52,6 +52,10 @@ class NPUTorchairModelRunner(NPUModelRunner): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp super().__init__(vllm_config, device) + if self.speculative_config: + self.actual_seq_lengths_q = list( + range(self.decode_token_per_req, self.max_num_tokens + 1, + self.decode_token_per_req)) self.attn_metadata_builder = self.attn_backend.get_builder_cls()( None, None, vllm_config, device) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7e553ae..9f99abc 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -306,17 +306,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer]] = None - self.actual_seq_lengths_q = [] + self.actual_seq_lengths_q: list[int] = [] self.decode_token_per_req = 1 if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num - self.actual_seq_lengths_q = [ - len for len in - range(self.decode_token_per_req, self.max_num_tokens + - 1, self.decode_token_per_req) - ] self.spec_attn_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool),