diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index fabfb7b4..10d6528c 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -256,6 +256,57 @@ class TestAscendMLAMetadataBuilder(TestBase): builder.chunked_prefill_enabled, mock_vllm_config.scheduler_config.chunked_prefill_enabled) + @patch('vllm.distributed.parallel_state.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_ascend_mla_metadata_builder_build_full_graph( + self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + + mock_dcp.world_size = 1 + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.rank_in_group = 0 + dcp_group.world_size = 1 + dcp_group.device_group = MagicMock() + mock_get_dcp_group.return_value = dcp_group + + mock_spec_config = MagicMock() + mock_spec_config.num_speculative_tokens = 1 + mock_spec_config.disable_padded_drafter_batch = True + mock_vllm_config.speculative_config = mock_spec_config + + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) + common_metadata = MagicMock() + model = MagicMock() + common_metadata.graph_pad_size = 8 + common_metadata.num_reqs = 4 + common_metadata.num_actual_tokens = 5 + common_metadata.max_query_len = 5 + common_metadata.seq_lens_cpu = torch.Tensor([9, 10, 8, 8]).int() + common_metadata.query_start_loc = torch.Tensor([0, 1, 2, 4, 5]).int() + common_metadata.query_start_loc_cpu = torch.Tensor([0, 1, 2, 4, + 5]).int() + common_metadata.positions = torch.Tensor([1, 2, 3, 4, 5, 6]).int() + block_table = torch.Tensor([[1, 0], [2, 0], [3, 0], [4, 0]]).int() + common_metadata.block_table_tensor = block_table + common_metadata.prefill_context_parallel_metadata = None + metadata = builder.build(0, common_metadata, model) + + self.assertEqual(metadata.decode.actual_seq_lengths_q, + [1, 2, 4, 5, 6, 6, 7, 8]) + self.assertEqual(metadata.decode.block_table.shape[0], 8) + @patch('vllm.distributed.parallel_state.get_dcp_group') @patch('vllm.distributed.parallel_state._DCP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @@ -307,6 +358,80 @@ class TestAscendMLAMetadataBuilder(TestBase): self.assertTrue(modified) input_batch.swap_states.assert_called_once_with(1, 2) + @patch('vllm.distributed.parallel_state.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_get_dcp_size, + mock_dcp, + mock_get_dcp_group): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + 
mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + mock_vllm_config.speculative_config = None + + mock_dcp.world_size = 1 + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.rank_in_group = 0 + dcp_group.world_size = 1 + dcp_group.device_group = MagicMock() + mock_get_dcp_group.return_value = dcp_group + + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) + input_seq_lens = [1, 2, 4, 5] + expect_output = [1, 2, 4, 5, 6, 6, 7, 8] + num_reqs = 4 + num_reqs_pad_size = 4 + output_seq_lens = builder.pad_actual_seq_len_q_mtp_disable_pad( + num_reqs_pad_size, num_reqs, input_seq_lens) + self.assertEqual(output_seq_lens, expect_output) + + @patch('vllm.distributed.parallel_state.get_dcp_group') + @patch('vllm.distributed.parallel_state._DCP', + new_callable=lambda: MagicMock(spec=GroupCoordinator)) + @patch("vllm.distributed.get_decode_context_model_parallel_world_size", + return_value=1) + def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_get_dcp_size, + mock_dcp, + mock_get_dcp_group): + mock_vllm_config = MagicMock() + mock_vllm_config.model_config.max_model_len = 1024 + mock_vllm_config.model_config.get_head_size.return_value = 64 + mock_vllm_config.model_config.dtype = torch.float16 + mock_vllm_config.cache_config.block_size = 16 + mock_vllm_config.scheduler_config.max_num_seqs = 4 + mock_vllm_config.scheduler_config.decode_max_num_seqs = 4 + mock_vllm_config.scheduler_config.chunked_prefill_enabled = False + mock_device = 'cpu' + mock_vllm_config.speculative_config = None + + mock_dcp.world_size = 1 + dcp_group = MagicMock(spec=GroupCoordinator) + dcp_group.rank_in_group = 0 + dcp_group.world_size = 1 + dcp_group.device_group = MagicMock() + mock_get_dcp_group.return_value = dcp_group + common_metadata = MagicMock() + common_metadata.actual_seq_lengths_q = [2, 4, 6, 8] + + builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config, + mock_device) + input_seq_lens = [2, 4, 6] + expect_output = [2, 4, 6, 8] + num_reqs = 3 + num_reqs_pad_size = 1 + output_seq_lens = builder.pad_actual_seq_len_q_mtp_enable_pad( + num_reqs_pad_size, num_reqs, input_seq_lens, common_metadata) + self.assertEqual(output_seq_lens, expect_output) + class TestAscendMLAMetadataBuilderBuild(TestBase): diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index 347fbd1c..2a9399fd 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -21,7 +21,9 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, ForwardContext from tests.ut.base import TestBase -from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper +from vllm_ascend.compilation.acl_graph import ( + ACLGraphEntry, ACLGraphWrapper, get_mtp_graph_params, set_mtp_graph_params, + update_mtp_graph_params_workspaces) class TestACLGraphEntry(TestBase): @@ -718,3 +720,24 @@ class TestACLGraphWrapper(TestBase): unwrapped = wrapper.unwrap() self.assertEqual(unwrapped, self.mock_runnable) + + +class TestMTPGraphParams(TestBase): + + def test_set_mtp_graph_params(self): + with patch('vllm_ascend.compilation.acl_graph._mtp_graph_params', + new=None): + set_mtp_graph_params([4]) + from vllm_ascend.compilation.acl_graph import _mtp_graph_params + self.assertIsNotNone(_mtp_graph_params) + + @patch('vllm_ascend.compilation.acl_graph._mtp_graph_params') + def 
test_update_mtp_graph_params_workspaces(self, mtp_graph_params_mock): + mtp_graph_params_mock.workspaces = {4: 5} + update_mtp_graph_params_workspaces(4, 6) + self.assertEqual(mtp_graph_params_mock.workspaces[4], 6) + + @patch('vllm_ascend.compilation.acl_graph._mtp_graph_params') + def test_get_mtp_graph_params(self, mtp_graph_params_mock): + graph_params = get_mtp_graph_params() + self.assertIs(mtp_graph_params_mock, graph_params) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 066697c6..8c477dac 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -72,7 +72,8 @@ def set_ascend_forward_context( batch_descriptor: Optional[BatchDescriptor] = None, prefetch_stream: torch.npu.Stream = None, model_instance: torch.nn.Module = None, - weight_prefetch_method: Optional[WeightPrefetchMethod] = None): + weight_prefetch_method: Optional[WeightPrefetchMethod] = None, + is_mtp_model=False): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. @@ -158,6 +159,7 @@ def set_ascend_forward_context( forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled forward_context.model_instance = model_instance forward_context.weight_prefetch_method = weight_prefetch_method + forward_context.is_mtp_model = is_mtp_model # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant. # It will be improved later by implementing operator fusion through the FX graph. diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5038fc95..5506b185 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -8,6 +8,7 @@ import torch.distributed as dist import torch_npu from torch import nn from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import (get_dcp_group, get_decode_context_model_parallel_rank, @@ -38,6 +39,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) from vllm_ascend.compilation.acl_graph import (get_graph_params, + get_mtp_graph_params, update_graph_params_workspaces) from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod @@ -337,6 +339,74 @@ class AscendMLAMetadataBuilder: # better way of doing this return modified_batch + def pad_actual_seq_len_q_mtp_enable_pad(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q, + common_attn_metadata): + """ + Pads actual_seq_lengths_q evenly to not exceed 16 tokens per request + in order to meet the requirement of npu_fused_infer_attention_score. + + In the Torchair scenario, the query lengths must be padded to the same length. + The npu_fused_infer_attention_score constraint also requires that the last element equal batch_size (num_tokens). + + For example: + batch_size=36, num_reqs_pad_size=2, num_reqs=16 + By default, each request should infer 2 tokens, which means actual_seq_lengths_q should be + [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36]. + + However, in the mtp torchair + PD scenario, actual_seq_lengths_q may be + [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the first decode request only has 1 token.
+ In order to meet the requirement of npu_fused_infer_attention_score, we need to pad actual_seq_lengths_q evenly to not exceed 16 tokens per request. + After padding, actual_seq_lengths_q should be similar to [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36] + """ + FIA_SEQ_LEN_LIMIT = 16 + need_padding = num_reqs_pad_size != 0 and \ + len(common_attn_metadata.actual_seq_lengths_q) > num_reqs and \ + common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT + if need_padding: + padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + start_val = actual_seq_lengths_q[-1] + end_val = padding_seq_len_q[-1] + + num_step = len(padding_seq_len_q) + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == len(padding_seq_len_q) + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + else: + actual_seq_lengths_q = actual_seq_lengths_q + common_attn_metadata.actual_seq_lengths_q[ + num_reqs:num_reqs + num_reqs_pad_size] + + return actual_seq_lengths_q + + def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, + actual_seq_lengths_q): + """ + Only used for acl full graph mode. + Pad actual_seq_lengths_q so that its last element equals T (of the TND layout) and + its length equals the batch_size of the main model. + + For example: + batch_size = 8, num_reqs = 4, num_speculative_tokens = 1 + input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req accepted a token) + After padding, actual_seq_lengths_q will be similar to [1, 2, 4, 5, 6, 6, 7, 8] + """ + need_padding = num_reqs_pad_size > 0 + if need_padding: + start_val = actual_seq_lengths_q[-1] + end_val = num_reqs + num_reqs_pad_size + num_step = num_reqs_pad_size + interpolated = np.round( + np.linspace(start_val, end_val, + num_step + 1)[1:]).astype(int).tolist() + assert interpolated[-1] == end_val + assert len(interpolated) == num_reqs_pad_size + actual_seq_lengths_q = actual_seq_lengths_q + interpolated + return actual_seq_lengths_q + def build( self, common_prefix_len: int, @@ -362,17 +432,25 @@ class AscendMLAMetadataBuilder: # it blocks on all previous kernels. device = self.device - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + # If graph_pad_size > -1, it means we are running in fullgraph mode. + graph_pad_size = common_attn_metadata.graph_pad_size + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. + if graph_pad_size > num_reqs and self.speculative_config.disable_padded_drafter_batch: + block_table = ( + common_attn_metadata.block_table_tensor[:graph_pad_size]) + else: + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + # NOTE: Currently, MTP-fullgraph is incompatible with pcp if self.pcp_size > 1: num_decodes_flatten = num_decodes * self.decode_threshold block_table = common_attn_metadata.block_table_tensor[: num_decodes_flatten + num_prefills] - if num_actual_tokens_pcp_padded is None: num_actual_tokens_pcp_padded = num_actual_tokens + # NOTE: Currently, MTP-fullgraph is incompatible with pcp slot_mapping = common_attn_metadata.slot_mapping[: num_actual_tokens_pcp_padded] input_positions = common_attn_metadata.positions[: @@ -565,6 +643,11 @@ class AscendMLAMetadataBuilder: block_table = block_table[:num_decodes_flatten, ...] else: block_table = block_table[:num_decodes, ...]
+ # NOTE: Currently, MTP-fullgraph is incompatible with pcp + # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1. + if graph_pad_size > num_decodes and \ + self.speculative_config.disable_padded_drafter_batch: + block_table = block_table[:graph_pad_size, ...] seq_lens_list = seq_lens.tolist() if num_computed_tokens_of_pcp_dcp is not None: @@ -586,6 +669,52 @@ class AscendMLAMetadataBuilder: else: cp_seq_len, batch_seq_mask = None, None + if graph_pad_size > num_reqs: + if self.speculative_config.disable_padded_drafter_batch: + num_reqs_pad_size = graph_pad_size - num_reqs + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q) + seq_lens_list = seq_lens_list + [0] * (graph_pad_size - \ + num_decodes) + num_block_pad_size = graph_pad_size - block_table.shape[0] + if num_block_pad_size > 0: + block_table_padding = torch.zeros( + (num_block_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat( + [block_table, block_table_padding], dim=0) + else: + num_token_pad_size = graph_pad_size - num_decode_tokens + num_reqs_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - num_reqs) + num_block_table_pad_size = ( + graph_pad_size // + common_attn_metadata.decode_token_per_req - + num_decodes) + seq_lens_list = seq_lens.tolist() + [0] * num_reqs_pad_size + slot_padding = torch.full((num_token_pad_size, ), + PAD_SLOT_ID, + dtype=slot_mapping.dtype, + device=slot_mapping.device) + slot_mapping = torch.cat([slot_mapping, slot_padding]) + block_table_padding = torch.zeros( + (num_block_table_pad_size, ) + block_table.shape[1:], + dtype=block_table.dtype, + device=block_table.device) + block_table = torch.cat([block_table, block_table_padding], + dim=0) + position_padding = torch.zeros( + num_token_pad_size, + dtype=input_positions.dtype, + device=input_positions.device) + input_positions = torch.cat( + [input_positions, position_padding]) + actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( + num_reqs_pad_size, num_reqs, actual_seq_lengths_q, + common_attn_metadata) + # TODO: After fullgraph supports MTP, this if branch needs to be deleted assert self.cos_cache is not None assert self.sin_cache is not None @@ -1267,8 +1396,11 @@ class AscendMLAImpl(MLAAttentionImpl): "actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths_kv": decode_meta.seq_lens_list, } - graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() if forward_context.capturing: stream = torch_npu.npu.current_stream() diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 6aaccc63..9b057011 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -241,7 +241,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape): def update_mla_attn_params(update_stream, forward_context, runtime_shape, speculative_config): - graph_params = get_graph_params() + if forward_context.is_mtp_model: + graph_params = get_mtp_graph_params() + else: + graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph.
with torch.npu.stream(update_stream): @@ -257,7 +260,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, softmax_lse) = param seq_lens_list = forward_context.attn_metadata[ key].decode.seq_lens_list - if speculative_config and speculative_config.method == "deepseek_mtp": + if speculative_config and speculative_config.method == "deepseek_mtp" \ + and not forward_context.is_mtp_model: actual_seq_lengths = forward_context.attn_metadata[ key].decode.actual_seq_lengths_q spec_multiple = speculative_config.num_speculative_tokens + 1 @@ -267,6 +271,13 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, spec_multiple * (i + 1) for i in range(runtime_shape // spec_multiple) ] + elif forward_context.is_mtp_model: + actual_seq_lengths = forward_context.attn_metadata[ + key].decode.actual_seq_lengths_q + block_table = forward_context.attn_metadata[ + key].decode.block_table + seq_lens_list = seq_lens_list + [0] * ( + len(actual_seq_lengths) - len(seq_lens_list)) else: seq_lens_list = seq_lens_list + [0] * (runtime_shape - len(seq_lens_list)) @@ -443,3 +454,32 @@ def update_graph_params_workspaces(num_tokens: int, workspace: int): def get_graph_params(): return _graph_params + + +_mtp_graph_params: Optional[GraphParams] = None + + +def set_mtp_graph_params(aclgraph_capture_sizes: set[int]): + global _mtp_graph_params + if _mtp_graph_params is not None: + raise ValueError("MTPGraph parameters have already been set!") + _mtp_graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any): + global _mtp_graph_params + if _mtp_graph_params is not None: + _mtp_graph_params.workspaces[num_tokens] = workspace + + +def get_mtp_graph_params(): + return _mtp_graph_params diff --git a/vllm_ascend/patch/worker/patch_deepseek_mtp.py b/vllm_ascend/patch/worker/patch_deepseek_mtp.py index 5f918b2d..c4df4d50 100644 --- a/vllm_ascend/patch/worker/patch_deepseek_mtp.py +++ b/vllm_ascend/patch/worker/patch_deepseek_mtp.py @@ -10,6 +10,12 @@ from vllm.model_executor.models.deepseek_mtp import \ from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer from vllm.model_executor.models.utils import maybe_prefix +from vllm_ascend.utils import vllm_version_is + +if vllm_version_is("0.11.0"): + from vllm.compilation.decorators import support_torch_compile + from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP + class SharedHead(nn.Module): @@ -51,4 +57,38 @@ def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None: topk_indices_buffer) +def predictor_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + spec_step_index: int = 0, +) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states 
+ + +# Patch this only for aclgraph support, as this is not supported in vLLM 0.11.0 +@support_torch_compile +class AscendDeepSeekMTP(DeepSeekMTP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init +if vllm_version_is("0.11.0"): + DeepSeekMultiTokenPredictorLayer.forward = predictor_forward diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 627411fe..df446537 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,5 +1,5 @@ import importlib -from typing import Optional +from typing import Optional, Union import numpy as np import torch @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from vllm.config import (CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import BatchDescriptor +from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model_loader @@ -32,7 +32,11 @@ from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import set_ascend_forward_context +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + set_mtp_graph_params, + update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, prefill_context_parallel_enable, @@ -52,9 +56,14 @@ logger = init_logger(__name__) PADDING_SLOT_ID = -1 +_deepseek_mtp_path = "vllm.model_executor.models.deepseek_mtp" +_deepseek_mtp_model = "DeepSeekMTP" +if vllm_version_is("0.11.0"): + _deepseek_mtp_path = "vllm_ascend.patch.worker.patch_deepseek_mtp" + _deepseek_mtp_model = "AscendDeepSeekMTP" + _MTP_MODELS = { - "DeepseekV3ForCausalLM": - ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), + "DeepseekV3ForCausalLM": (_deepseek_mtp_path, _deepseek_mtp_model), "Qwen3NextForCausalLM": ("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP") } @@ -75,6 +84,9 @@ def _load_model(architecture): class MtpProposer(Proposer): + # TODO: Find out why ModelRunner does not need this explicit typing.
+ model: Union[nn.Module, ACLGraphWrapper] + def __init__( self, vllm_config: VllmConfig, @@ -203,6 +215,15 @@ class MtpProposer(Proposer): process_weights_after_loading(self.model, draft_model_config, target_device) + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ): + self.update_stream: torch.npu.Stream = torch.npu.Stream() + set_mtp_graph_params( + self.vllm_config.compilation_config.cudagraph_capture_sizes) + self.model = ACLGraphWrapper(self.model, + self.vllm_config, + runtime_mode=CUDAGraphMode.FULL) + @torch.inference_mode() def dummy_run(self, num_tokens: int, @@ -222,12 +243,55 @@ class MtpProposer(Proposer): moe_comm_type = self.runner._select_moe_comm_method( num_tokens, with_prefill) - attn_metadata = None + if skip_attn: + attn_metadata = None + elif aclgraph_runtime_mode == CUDAGraphMode.FULL: + if len(self.runner.attn_groups) > 0: + num_computed_tokens_cpu = ( + self.runner.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs]) + common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=self.runner.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.runner. + query_start_loc_cpu[:num_reqs + 1], + seq_lens_cpu=self.runner.seq_lens_cpu, + seq_lens=self.runner.seq_lens_cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=self.num_speculative_tokens + 1, + num_computed_tokens_cpu=num_computed_tokens_cpu, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=self.runner.input_batch.block_table[0]. + get_device_tensor()[:num_reqs], + slot_mapping=self.runner.input_batch.block_table[0]. + slot_mapping, + positions=self.runner.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + cos=self.runner.cos, + sin=self.runner.sin, + ) + + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build_for_graph_capture( + common_attn_metadata, AscendAttentionState.SpecDecoding, + self.runner.get_model()) + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp + else: + attn_metadata = None + else: + attn_metadata = None input_ids = self.input_ids[:num_tokens] positions = self.positions[:num_tokens] previous_hidden_states = self.hidden_states[:num_tokens] - for _ in range(self.num_speculative_tokens): + for i in range(self.num_speculative_tokens): + if i > 0: + aclgraph_runtime_mode = CUDAGraphMode.NONE with set_ascend_forward_context( attn_metadata, self.vllm_config, @@ -239,10 +303,19 @@ class MtpProposer(Proposer): in_profile_run=self.runner.in_profile_run, num_actual_tokens=0, aclgraph_runtime_mode=aclgraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + is_mtp_model=True): self.model(input_ids=input_ids, positions=positions, hidden_states=previous_hidden_states) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ + not forward_context.capturing: + if self.vllm_config.model_config.use_mla: + update_mla_attn_params( + self.update_stream, forward_context, + positions.shape[0], + self.vllm_config.speculative_config) if with_prefill: break @@ -324,6 +397,7 @@ class MtpProposer(Proposer): common_attn_metadata.query_start_loc = \ query_start_loc_pcp_full[:num_reqs + 1] if self.speculative_config.disable_padded_drafter_batch: + # NOTE: Currently, MTP-fullgraph is 
incompatible with pcp token_indices_to_sample = None common_attn_metadata, token_indices =\ self._prepare_inputs( @@ -358,6 +432,8 @@ class MtpProposer(Proposer): long_seq_metadata=long_seq_metadata, num_prefill_reqs=num_prefill_reqs, num_decode_reqs=num_decode_reqs, + scheduler_output=scheduler_output, + num_scheduled_tokens=num_scheduled_tokens, ) return draft_token_ids @@ -460,6 +536,13 @@ class MtpProposer(Proposer): token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) + common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( + common_attn_metadata.slot_mapping[token_indices]) + common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) + + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward, + # so we do not need to fix them. But if they are used in the future, + # we should fix them. spec_common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), @@ -472,7 +555,7 @@ class MtpProposer(Proposer): num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], + slot_mapping=common_attn_metadata.slot_mapping, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, positions=common_attn_metadata.positions[token_indices], attn_mask=self.runner.attn_mask, @@ -502,6 +585,8 @@ class MtpProposer(Proposer): long_seq_metadata=None, num_prefill_reqs=0, num_decode_reqs=0, + scheduler_output: SchedulerOutput = None, + num_scheduled_tokens: int = 0, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -585,14 +670,11 @@ class MtpProposer(Proposer): assert self.runner is not None - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_mtp = builder.build(0, common_attn_metadata, - self.runner.get_model()) - attn_metadata = {} - for layer_name in self.attn_layer_name: - attn_metadata[layer_name] = attn_metadata_mtp - - if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph( + num_scheduled_tokens) + elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: # Acl graph mode, add padding to the batch size num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: @@ -609,19 +691,39 @@ class MtpProposer(Proposer): moe_comm_type = self.runner._select_moe_comm_method( num_input_tokens, with_prefill) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=False) + + if scheduler_output: + max_query_len = common_attn_metadata.max_query_len + uniform_decode = (max_query_len in list( + range(1, self.num_speculative_tokens + + 2))) and (scheduler_output.total_num_scheduled_tokens + == self.runner.input_batch.num_reqs * + (self.num_speculative_tokens + 1)) + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=uniform_decode) + else: + batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=False) aclgraph_runtime_mode, batch_descriptor = \ self.runner.aclgraph_dispatcher.dispatch(batch_descriptor) - if aclgraph_runtime_mode not in [ - CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE - ]: - # Fallback to piecewise graph, when acl full graph is enabled
- logger.debug( - "Currently the eagle proposer only supports cudagraph_mode " - f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} " - "to CUDAGraphMode.PIECEWISE") - aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE + + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( + ) and aclgraph_runtime_mode == CUDAGraphMode.FULL: + graph_pad_size = num_input_tokens + else: + # Currently, runner.graph_pad_size will always be -1. + graph_pad_size = self.runner.graph_pad_size + + # If fullgraph is used and disable_padded_drafter_batch=True, we need to + # update the graph_pad_size in common_attn_metadata to tell the + # builder to pad some elements. + common_attn_metadata.graph_pad_size = graph_pad_size + builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata_mtp = builder.build(0, common_attn_metadata, + self.runner.get_model()) + attn_metadata = {} + for layer_name in self.attn_layer_name: + attn_metadata[layer_name] = attn_metadata_mtp for step in range(self.num_speculative_tokens): with set_ascend_forward_context( @@ -635,7 +737,8 @@ class MtpProposer(Proposer): aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, in_profile_run=self.runner.in_profile_run, - num_actual_tokens=num_tokens): + num_actual_tokens=num_tokens, + is_mtp_model=True): with ProfileExecuteDuration().capture_async('mtp_forward'): model_kwargs = {} model_kwargs["attn_metadata"] = attn_metadata @@ -644,6 +747,13 @@ class MtpProposer(Proposer): input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], hidden_states=self.hidden_states[:num_input_tokens]) + forward_context = get_forward_context() + if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: + if self.vllm_config.model_config.use_mla: + update_mla_attn_params( + self.update_stream, forward_context, + num_input_tokens, + self.vllm_config.speculative_config) num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): @@ -699,12 +809,21 @@ class MtpProposer(Proposer): input_ids = draft_token_ids_list[-1].int() positions += 1 - attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ - 1:batch_size + 1].tolist() - attn_metadata_i.decode.cos = builder.cos_cache[ - positions].unsqueeze(1).unsqueeze(2) - attn_metadata_i.decode.sin = builder.sin_cache[ - positions].unsqueeze(1).unsqueeze(2) + # When disable_padded_drafter_batch=False, these params probably should not be updated. + if self.speculative_config.disable_padded_drafter_batch or \ + aclgraph_runtime_mode != CUDAGraphMode.FULL: + attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + 1:batch_size + 1].tolist() + if aclgraph_runtime_mode == CUDAGraphMode.FULL: + attn_metadata_i.decode.actual_seq_lengths_q = \ + builder.pad_actual_seq_len_q_mtp_disable_pad( + graph_pad_size - batch_size, + batch_size, + attn_metadata_i.decode.actual_seq_lengths_q) + attn_metadata_i.decode.cos = builder.cos_cache[ + positions].unsqueeze(1).unsqueeze(2) + attn_metadata_i.decode.sin = builder.sin_cache[ + positions].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex # to remove such requests from the batch, we keep them in the batch @@ -735,6 +854,10 @@ class MtpProposer(Proposer): self.positions[:batch_size] = clamped_positions self.hidden_states[:hidden_states.shape[0]] = hidden_states attn_metadata_i.slot_mapping[:batch_size] = slot_mapping + if self.speculative_config.disable_padded_drafter_batch: + self.positions[batch_size:num_input_tokens] = 0 + self.input_ids[batch_size:num_input_tokens] = 0 + self.hidden_states[batch_size:num_input_tokens].fill_(0) if attn_metadata_i.prefill is not None: attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens @@ -751,12 +874,19 @@ class MtpProposer(Proposer): attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( ) + decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list + if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ + self.speculative_config.disable_padded_drafter_batch: + attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [ + 0 + ] * (graph_pad_size - len(decode_seq_lens_list)) attn_metadata_i.decode.input_positions = self.positions[: num_input_tokens] attn_metadata_i.decode.max_seq_lens += 1 attn_metadata_i.decode.max_seq_lens = min( attn_metadata_i.decode.max_seq_lens, self.runner.model_config.max_model_len) + torch.npu.synchronize() # mtp>1: [batch_size, k] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) @@ -915,6 +1045,9 @@ class MtpProposer(Proposer): total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward, + # so we do not need to fix them. But if they are used in the future, + # we should fix them. spec_common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, query_start_loc_cpu=query_start_loc_cpu, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 817eaa3b..544b3edc 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3112,7 +3112,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.drafter.dummy_run( num_tokens=num_tokens, with_prefill=with_prefill, - skip_attn=True, num_reqs=num_reqs, num_tokens_across_dp=num_tokens_across_dp, aclgraph_runtime_mode=aclgraph_runtime_mode,
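For reference, a minimal standalone sketch of the even-interpolation padding that pad_actual_seq_len_q_mtp_disable_pad performs in this patch (the helper name below is illustrative, not part of the change); it reproduces the expectation encoded in test_pad_actual_seq_lens_q_mtp_disable_pad:

```python
# Standalone sketch of the even interpolation used to pad actual_seq_lengths_q
# for the ACL full-graph MTP path. The function name here is illustrative.
import numpy as np


def pad_cumulative_seq_lens_q(seq_lens_q, num_reqs, num_reqs_pad_size):
    """Pad cumulative query lengths so the list has num_reqs + num_reqs_pad_size
    entries and its last entry equals that padded batch size."""
    if num_reqs_pad_size <= 0:
        return seq_lens_q
    start_val = seq_lens_q[-1]
    end_val = num_reqs + num_reqs_pad_size
    # Evenly interpolate the padded cumulative lengths between the last real
    # value and the padded batch size, then round to integers.
    interpolated = np.round(
        np.linspace(start_val, end_val,
                    num_reqs_pad_size + 1)[1:]).astype(int).tolist()
    return seq_lens_q + interpolated


# Mirrors test_pad_actual_seq_lens_q_mtp_disable_pad in this patch:
# batch_size = 8, num_reqs = 4, the 3rd request accepted a token.
assert pad_cumulative_seq_lens_q([1, 2, 4, 5], num_reqs=4,
                                 num_reqs_pad_size=4) == [1, 2, 4, 5, 6, 6, 7, 8]
```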