From 5b179c53f1669d9dfceeba1ff105423678be7004 Mon Sep 17 00:00:00 2001
From: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Date: Wed, 10 Dec 2025 20:11:09 +0800
Subject: [PATCH] [FEAT] Support DeepSeek-V3.2 with `FULL_DECODE_ONLY` mode
 (#4706)

### What this PR does / why we need it?
The first commit supports `FULL_DECODE_ONLY`:

- Update `AscendSFAMetadataBuilder` to use `num_input_tokens` for slicing slots and positions, ensuring fixed tensor shapes.
- Implement padding logic for `query_start_loc` in `NPUModelRunner` to support uniform decode in full graph mode, aligning with GPU runner behavior (an illustrative sketch is appended after the diff).
- Adjust MLA cosine cache allocation to occur independently of graph mode and switch to using device-resident sequence lengths for attention metadata.
- Remove redundant slicing of hidden states and outputs in `AscendSFAImpl` and optimize the `sin`/`cos` cache updates (see the second sketch appended after the diff).

The second commit takes MTP into account and applies the same adjustments to the MTP (speculative decoding) proposer path.

The remaining commits are bug fixes.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
Test cases are still needed.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: Yizhou Liu
---
 tests/ut/attention/test_sfa_v1.py       |  6 +++
 vllm_ascend/attention/attention_v1.py   | 23 +--------
 vllm_ascend/attention/sfa_v1.py         | 66 +++++++++++++-----------
 vllm_ascend/compilation/acl_graph.py    |  7 ++-
 vllm_ascend/spec_decode/mtp_proposer.py | 28 +++++-----
 vllm_ascend/worker/model_runner_v1.py   | 68 ++++++++++++++++++++-----
 6 files changed, 120 insertions(+), 78 deletions(-)

diff --git a/tests/ut/attention/test_sfa_v1.py b/tests/ut/attention/test_sfa_v1.py
index 3db637c1..06441306 100644
--- a/tests/ut/attention/test_sfa_v1.py
+++ b/tests/ut/attention/test_sfa_v1.py
@@ -124,6 +124,9 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.attn_mask = None
         common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         common_attn_metadata.block_table_tensor = torch.randn(100, 4)
+        common_attn_metadata.cos = None
+        common_attn_metadata.sin = None
+        common_attn_metadata.num_input_tokens = 100
 
         model = MagicMock()
         model.model.layers = [MagicMock() for _ in range(10)]
@@ -166,6 +169,9 @@ class TestAscendSFAMetadataBuilder(TestBase):
         common_attn_metadata.attn_mask = None
         common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         common_attn_metadata.block_table_tensor = torch.randn(100, 4)
+        common_attn_metadata.cos = None
+        common_attn_metadata.sin = None
+        common_attn_metadata.num_input_tokens = 100
 
         model = MagicMock()
         model.model.layers = [MagicMock() for _ in range(10)]
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 8ef50a43..d97acf65 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -297,30 +297,11 @@ class AscendAttentionMetadataBuilder:
         slot_mapping = common_attn_metadata.slot_mapping[:
                                                          num_actual_tokens_pcp_padded]
-        # 
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] + attn_mask = common_attn_metadata.attn_mask attn_state = common_attn_metadata.attn_state - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_reqs - + 1] - if common_attn_metadata.num_input_tokens > num_actual_tokens: - padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens - seq_lens = torch.cat([ - seq_lens, - torch.tensor([padded_num_tokens - ]).to(seq_lens.device).to(seq_lens.dtype) - ]) - block_table_padding = torch.zeros( - (padded_num_tokens, ) + block_table.shape[1:], - dtype=block_table.dtype, - device=block_table.device) - block_table = torch.cat([block_table, block_table_padding], dim=0) - query_start_loc_cpu = torch.cat([ - query_start_loc_cpu, - torch.tensor([query_start_loc_cpu[-1] + padded_num_tokens]).to( - query_start_loc_cpu.device).to(query_start_loc_cpu.dtype) - ]) + # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device query_start_loc = query_start_loc_cpu.pin_memory().to( self.device, non_blocking=True) is_causal_pooling = None diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 550f08ed..cbf5833b 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -163,16 +163,12 @@ class AscendSFAMetadataBuilder: ) -> AscendSFAMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - device = self.device + num_input_tokens = common_attn_metadata.num_input_tokens - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens].to( - device, - non_blocking=True) + block_table = common_attn_metadata.block_table_tensor[:num_reqs] + slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens] input_positions = common_attn_metadata.positions[: - num_actual_tokens].long( + num_input_tokens].long( ) query_start_loc = common_attn_metadata.query_start_loc query_lens = query_start_loc[1:] - query_start_loc[:-1] @@ -189,15 +185,23 @@ class AscendSFAMetadataBuilder: self.sin_cache = self.sin_cache.to( # type: ignore self.model_config.dtype) # type: ignore - cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to( - torch.int32).to(device, non_blocking=True) - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to( - torch.int32).to(device, non_blocking=True) + cum_query_lens = common_attn_metadata.query_start_loc[1:num_reqs + 1] + seq_lens = common_attn_metadata.seq_lens[:num_reqs] - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) + cos = common_attn_metadata.cos + sin = common_attn_metadata.sin + + assert self.cos_cache is not None and self.sin_cache is not None + new_cos = self.cos_cache[input_positions][:, None, None] + new_sin = self.sin_cache[input_positions][:, None, None] + + if (cos is not None and sin is not None + and num_input_tokens <= cos.shape[0] + and num_input_tokens <= sin.shape[0]): + cos[:num_input_tokens] = new_cos + sin[:num_input_tokens] = new_sin + else: + cos, sin = new_cos, new_sin sfa_cp_context = None if self.enable_sfa_cp: @@ -268,8 +272,8 @@ class AscendSFAMetadataBuilder: attn_mask=common_attn_metadata.attn_mask, attn_state=common_attn_metadata.attn_state, block_tables=block_table, - sin=sin, - cos=cos, + sin=sin[:num_input_tokens], + 
cos=cos[:num_input_tokens], sfa_cp_context=sfa_cp_context) def build_for_graph_capture( @@ -278,7 +282,10 @@ class AscendSFAMetadataBuilder: attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, model: Optional[nn.Module] = None, ): - if attn_state == AscendAttentionState.DecodeOnly: + if attn_state in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding + }: attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -681,29 +688,29 @@ class AscendSFAImpl(MLAAttentionImpl): self.q_proj.quant_bias = None torch.npu.empty_cache() - def _sfa_preprocessc_decode( + def _sfa_preprocess_decode( self, hidden_states: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, need_gather_q_kv: bool, - num_actual_tokens: int, + num_input_tokens: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states.contiguous(), need_gather_q_kv) k_nope, k_pe = kv_cache[0], kv_cache[1] ql_nope = torch.empty( - (num_actual_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), + (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) q_pe = torch.empty( - (num_actual_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), + (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) q_c = torch.empty( - (num_actual_tokens, self.q_lora_rank), + (num_input_tokens, self.q_lora_rank), dtype=hidden_states.dtype, device=hidden_states.device, ) @@ -721,7 +728,7 @@ class AscendSFAImpl(MLAAttentionImpl): self.W_UK_T, k_nope, k_pe, - attn_metadata.slot_mapping[:num_actual_tokens].flatten(), + attn_metadata.slot_mapping, quant_scale0=self.quant_scale0, quant_offset0=self.quant_offset0, bias0=self.quant_bias_qkv, @@ -761,25 +768,22 @@ class AscendSFAImpl(MLAAttentionImpl): reach_layer_for_shared_weight_series(self.o_proj) return output.fill_(0) has_prefill = attn_metadata.has_prefill - num_actual_tokens = attn_metadata.num_actual_tokens cos = attn_metadata.cos sin = attn_metadata.sin actual_seq_lengths_query = attn_metadata.cum_query_lens actual_seq_lengths_key = attn_metadata.seq_lens - hidden_states = hidden_states[:num_actual_tokens] if self.enable_sfa_cp: need_gather_q_kv = False # Inputs and outputs may be padded for CUDA graphs output_padded = output - output = output[:num_actual_tokens] if self.enable_mlapo and not forward_context.with_prefill: - hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocessc_decode( + hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, need_gather_q_kv=need_gather_q_kv, - num_actual_tokens=num_actual_tokens, + num_input_tokens=attn_metadata.num_input_tokens, ) else: assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." 
@@ -802,7 +806,7 @@ class AscendSFAImpl(MLAAttentionImpl): if has_prefill: wait_for_kv_layer_from_connector(layer_name) - slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens] + slot_mapping = attn_metadata.slot_mapping slot_mapping_cp = None if self.enable_sfa_cp: assert attn_metadata.sfa_cp_context is not None diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 4ddc1d85..ab611445 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -273,6 +273,9 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, key].decode.actual_seq_lengths_q block_table = forward_context.attn_metadata[ key].decode.block_table + # TODO: This is a hack and should be fixed in the future. + if speculative_config.disable_padded_drafter_batch: + block_table = block_table[:len(actual_seq_lengths)] seq_lens_list = seq_lens_list + [0] * ( len(actual_seq_lengths) - len(seq_lens_list)) else: @@ -427,7 +430,7 @@ class GraphParams: _graph_params: Optional[GraphParams] = None -def set_graph_params(aclgraph_capture_sizes: set[int]): +def set_graph_params(aclgraph_capture_sizes: list[int]): global _graph_params if _graph_params is not None: raise ValueError("Graph parameters have already been set!") @@ -456,7 +459,7 @@ def get_graph_params(): _mtp_graph_params: Optional[GraphParams] = None -def set_mtp_graph_params(aclgraph_capture_sizes: set[int]): +def set_mtp_graph_params(aclgraph_capture_sizes: list[int]): global _mtp_graph_params if _mtp_graph_params is not None: raise ValueError("MTPGraph parameters have already been set!") diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index eb71bfb2..8ff325ff 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -32,7 +32,6 @@ from vllm_ascend.ascend_forward_context import (MoECommType, from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - set_mtp_graph_params, update_mla_attn_params) from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable, @@ -214,8 +213,6 @@ class MtpProposer(Proposer): if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( ): self.update_stream: torch.npu.Stream = torch.npu.Stream() - set_mtp_graph_params( - self.vllm_config.compilation_config.cudagraph_capture_sizes) self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) @@ -254,9 +251,10 @@ class MtpProposer(Proposer): query_start_loc_cpu=self.runner. 
query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.runner.seq_lens_cpu, - seq_lens=self.runner.seq_lens_cpu[:num_reqs], + seq_lens=self.runner.seq_lens[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, + num_input_tokens=num_tokens, max_query_len=self.num_speculative_tokens + 1, num_computed_tokens_cpu=num_computed_tokens_cpu, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, @@ -289,7 +287,7 @@ class MtpProposer(Proposer): positions = self.positions[:num_tokens] previous_hidden_states = self.hidden_states[:num_tokens] for i in range(self.num_speculative_tokens): - if i > 0: + if i > 0 and not skip_attn and aclgraph_runtime_mode == CUDAGraphMode.FULL: aclgraph_runtime_mode = CUDAGraphMode.NONE with set_ascend_forward_context( attn_metadata, @@ -316,7 +314,7 @@ class MtpProposer(Proposer): forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \ not forward_context.capturing: - if self.vllm_config.model_config.use_mla: + if self.vllm_config.model_config.use_mla and not self.use_sparse: update_mla_attn_params( self.update_stream, forward_context, num_tokens, self.vllm_config.speculative_config) @@ -514,6 +512,7 @@ class MtpProposer(Proposer): # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + num_actual_reqs = len(num_draft_tokens) num_rejected_tokens = [ n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) @@ -522,8 +521,11 @@ class MtpProposer(Proposer): dtype=torch.int32) device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_actual_reqs + + 1] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs] + new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] new_query_len_per_req = query_start_loc_cpu[ @@ -587,6 +589,7 @@ class MtpProposer(Proposer): num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, + num_input_tokens=common_attn_metadata.num_input_tokens, max_query_len=new_query_len_per_req.max().item(), block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, @@ -704,8 +707,8 @@ class MtpProposer(Proposer): assert self.runner is not None - if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( - ) and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]: + if self.runner.use_aclgraph and num_scheduled_tokens <= self.cudagraph_batch_sizes[ + -1]: num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) elif self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]: @@ -797,7 +800,7 @@ class MtpProposer(Proposer): hidden_states=hidden_states) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - if self.vllm_config.model_config.use_mla: + if self.vllm_config.model_config.use_mla and not self.use_sparse: update_mla_attn_params( self.update_stream, forward_context, num_input_tokens, @@ -1109,9 +1112,10 @@ class MtpProposer(Proposer): spec_common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=common_attn_metadata.query_start_loc, query_start_loc_cpu=query_start_loc_cpu, - seq_lens_cpu=common_attn_metadata.seq_lens, + 
seq_lens_cpu=common_attn_metadata.seq_lens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, + num_input_tokens=common_attn_metadata.num_input_tokens, max_query_len=new_query_len_per_req.max().item(), actual_seq_lengths_q=self.runner.actual_seq_lengths_q, block_table_tensor=common_attn_metadata.block_table_tensor, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 715b3937..b9e23334 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -124,6 +124,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, # yapf: disable from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, set_graph_params, + set_mtp_graph_params, update_attn_dcp_pcp_params, update_attn_params, update_mla_attn_dcp_pcp_params, @@ -406,8 +407,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): dtype=torch.int32, device=self.device) - if self.vllm_config.model_config.use_mla and \ - self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY: + # NOTE: This will have some extra memory allocated, is it OK? + if self.vllm_config.model_config.use_mla: rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.cos = torch.ones(self.max_num_reqs * self.decode_token_per_req, @@ -1843,6 +1844,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + # NOTE: This is strange, why did we use total_num_scheduled_tokens before? slot_mapping_size = (total_num_scheduled_tokens if self.pcp_size == 1 else total_num_scheduled_tokens * self.pcp_size - @@ -1864,7 +1866,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() - slot_mapping = blk_table.slot_mapping[:slot_mapping_size] blk_table.slot_mapping[slot_mapping_size:].fill_(0) if self.pcp_size > 1: slot_mapping_for_pcp = blk_table.slot_mapping[: @@ -1884,14 +1885,48 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): slot_mapping_size] slot_mapping_for_pcp[:long_seq_metadata. num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping - slot_mapping = slot_mapping_for_pcp + blk_table.slot_mapping[:long_seq_metadata.num_actual_tokens_pcp_padded] = \ + slot_mapping_for_pcp + slot_mapping = blk_table.slot_mapping + + # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs + # has been split to multiple parts, and there are 3 parts that is related to this + # `num_reqs`, we'll take `query_start_loc` as an example: + # 1. self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens + # 2. get `num_reqs_padded`, this depends on dispatcher and which is why we have the + # following simplified `dispatch` logic here, we try to minimize the impact + # 3. query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1] + uniform_decode = (max_num_scheduled_tokens == self.uniform_decode_query_len) \ + and (total_num_scheduled_tokens == max_num_scheduled_tokens * num_reqs) + + # TODO: We should make this official ASAP. Also note that if we pad here, + # the builders won’t need to add any extra padding. 
+ if self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + uniform_decode: + num_reqs_padded = num_input_tokens // self.uniform_decode_query_len + pad_size = num_reqs_padded - num_reqs + if pad_size > 0: + last_query_loc = self.query_start_loc[num_reqs] + + steps = torch.arange(1, + pad_size + 1, + device=self.device, + dtype=self.query_start_loc.dtype) + fill_values = last_query_loc + ( + steps * self.uniform_decode_query_len) + + self.query_start_loc[num_reqs + 1:num_reqs_padded + + 1] = fill_values + # So we are trying to simulate the behavior of GPUModelRunner's + # prepare_inputs for uniform decode mode by padding query_start_loc + num_reqs = num_reqs_padded # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu[:num_reqs], - seq_lens=self.seq_lens_cpu[:num_reqs], + seq_lens=self.seq_lens[:num_reqs], num_reqs=num_reqs, num_actual_tokens=slot_mapping_size, num_input_tokens=num_input_tokens, @@ -2876,6 +2911,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): seq_lens = max_query_len self.seq_lens_np[:num_reqs] = seq_lens self.seq_lens_np[num_reqs:] = 0 + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) cu_num_tokens, arange = self._get_cumsum_and_arange( num_scheduled_tokens) @@ -2906,21 +2943,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): [0] * dcp_world_size for _ in range(pcp_world_size) ] for _ in range(num_tokens)] long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp + # QUESTION: Why do we separately set query_start_loc for spec in the first place? + # While in _prepare_inputs we don't? if self.speculative_config: - query_start_loc = torch.tensor( + self.query_start_loc[:num_reqs + 1] = torch.tensor( [0] + self.actual_seq_lengths_q[:num_reqs], device=self.device, dtype=torch.int32) - else: - query_start_loc = self.query_start_loc[:num_reqs + 1] common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=query_start_loc, + query_start_loc=self.query_start_loc[:num_reqs + 1], query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], seq_lens_cpu=self.seq_lens_cpu, - seq_lens=self.seq_lens_cpu[:num_reqs], + seq_lens=self.seq_lens[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, + num_input_tokens=num_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, block_table_tensor=block_table_tensor[:num_reqs], slot_mapping=slot_mapping, @@ -3210,7 +3248,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp, aclgraph_runtime_mode=aclgraph_runtime_mode, batch_descriptor=batch_descriptor, - dummy_compute_logits=dummy_drafter_compute_logits) + dummy_compute_logits=dummy_drafter_compute_logits, + skip_attn=not force_attention) if self.in_profile_run and self.dynamic_eplb: self.model.clear_all_moe_loads() if not self.in_profile_run and self.dynamic_eplb: @@ -3373,7 +3412,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): # wrap the model with full graph wrapper if needed. 
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.update_stream: torch.npu.Stream = torch.npu.Stream() - set_graph_params(self.compilation_config.cudagraph_capture_sizes) self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) @@ -4092,6 +4130,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.aclgraph_batch_sizes = (capture_sizes if capture_sizes is not None else []) + # NOTE: Since aclgraph_batch_sizes cannot be determined until here, + # we set the graph params right before initializing the keys. + set_graph_params(self.aclgraph_batch_sizes) + if self.speculative_config: + set_mtp_graph_params(self.aclgraph_batch_sizes) + self.aclgraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, self.uniform_decode_query_len)
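
Illustrative sketch of the uniform-decode `query_start_loc` padding described in the PR body. This is a minimal, standalone rendering of the idea added to `NPUModelRunner._prepare_inputs`; the helper name `pad_query_start_loc` and the concrete sizes in the example are hypothetical and only meant to clarify the behavior under full-graph capture/replay.

```python
import torch


def pad_query_start_loc(query_start_loc: torch.Tensor, num_reqs: int,
                        num_input_tokens: int,
                        uniform_decode_query_len: int) -> int:
    """Extend query_start_loc with evenly spaced dummy requests so the batch
    matches the padded token count used for full-graph capture/replay.
    Returns the padded request count."""
    num_reqs_padded = num_input_tokens // uniform_decode_query_len
    pad_size = num_reqs_padded - num_reqs
    if pad_size > 0:
        last = query_start_loc[num_reqs]
        steps = torch.arange(1,
                             pad_size + 1,
                             device=query_start_loc.device,
                             dtype=query_start_loc.dtype)
        # Each padded "request" contributes uniform_decode_query_len tokens.
        query_start_loc[num_reqs + 1:num_reqs_padded + 1] = \
            last + steps * uniform_decode_query_len
    return num_reqs_padded


# Example: 3 real decode requests (1 token each) padded to a capture size of 8 tokens.
qsl = torch.zeros(16, dtype=torch.int32)
qsl[1:4] = torch.tensor([1, 2, 3], dtype=torch.int32)
padded_reqs = pad_query_start_loc(qsl, num_reqs=3, num_input_tokens=8,
                                  uniform_decode_query_len=1)
print(padded_reqs)       # 8
print(qsl[:9].tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```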
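
Illustrative sketch of the `sin`/`cos` cache update performed in `AscendSFAMetadataBuilder.build`: when persistent buffers are available and large enough, the freshly gathered values are written into them in place (so graph-captured kernels keep reading from stable storage); otherwise the new tensors are used directly. The helper and buffer names here are hypothetical, shown only to mirror the write-in-place-or-fall-back behavior in the diff.

```python
import torch


def update_rope_caches(cos_cache, sin_cache, positions, num_input_tokens,
                       cos_buf=None, sin_buf=None):
    """Gather per-token cos/sin and, when persistent buffers are provided and
    large enough, copy into them in place; otherwise return the new tensors."""
    new_cos = cos_cache[positions][:, None, None]
    new_sin = sin_cache[positions][:, None, None]
    if (cos_buf is not None and sin_buf is not None
            and num_input_tokens <= cos_buf.shape[0]
            and num_input_tokens <= sin_buf.shape[0]):
        cos_buf[:num_input_tokens] = new_cos
        sin_buf[:num_input_tokens] = new_sin
        return cos_buf[:num_input_tokens], sin_buf[:num_input_tokens]
    return new_cos, new_sin


# Example with persistent buffers sized for the largest captured batch.
rope_dim, max_tokens = 64, 8
cos_cache = torch.randn(1024, rope_dim)
sin_cache = torch.randn(1024, rope_dim)
cos_buf = torch.ones(max_tokens, 1, 1, rope_dim)
sin_buf = torch.ones(max_tokens, 1, 1, rope_dim)
positions = torch.tensor([5, 6, 7])
cos, sin = update_rope_caches(cos_cache, sin_cache, positions, 3,
                              cos_buf, sin_buf)
print(cos.shape)  # torch.Size([3, 1, 1, 64]), a view into cos_buf
```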