diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 6ee943d..aa2c818 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -495,11 +495,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ring_mla_mask_size = 512
         self.prefill_mask = None
 
-        # Adapt torch air graph mode with spec decoding.
-        speculative_config = vllm_config.speculative_config
-        if speculative_config is not None:
-            self.spec_token_num = speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
+        self.speculative_config = vllm_config.speculative_config
 
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
@@ -811,7 +807,11 @@ class AscendMLAImpl(MLAAttentionImpl):
                          self.qk_rope_head_dim)
 
         input_layout = "BNSD"
-        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ] and self.speculative_config is not None:
+            # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index c4b9ac2..3f54fdb 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -676,11 +676,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         self.prefill_mask = None
         self.ring_mla_mask_size = 512
 
-        # Adapt torch air graph mode with spec decoding.
-        speculative_config = get_current_vllm_config().speculative_config
-        if speculative_config is not None:
-            self.spec_token_num = speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
+        self.speculative_config = get_current_vllm_config().speculative_config
 
     def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
@@ -1012,7 +1008,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                          self.qk_rope_head_dim)
 
         input_layout = "BNSD"
-        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ] and self.speculative_config is not None:
+            # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
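The change in both files is the same: instead of caching `num_speculative_tokens`, the impl keeps the whole `speculative_config` and uses its presence to decide when the TND input layout applies, so the TND path now covers spec decoding folded into a ChunkedPrefill step as well as pure SpecDecoding. Below is a minimal, self-contained sketch of that selection logic (not part of the patch): `select_input_layout` is a hypothetical helper, and the enum is a stand-in for the real `AscendAttentionState` in `vllm_ascend/attention/attention_v1.py`.

```python
from enum import Enum, auto


class AscendAttentionState(Enum):
    """Stand-in for the real attention-state enum; member names only."""
    PrefillNoCache = auto()
    ChunkedPrefill = auto()
    DecodeOnly = auto()
    SpecDecoding = auto()


def select_input_layout(attn_state: AscendAttentionState,
                        speculative_config) -> str:
    """Mirror the new condition: TND for spec decoding, whether it runs as a
    pure SpecDecoding step or inside a ChunkedPrefill step; BNSD otherwise."""
    if attn_state in (
            AscendAttentionState.SpecDecoding,
            AscendAttentionState.ChunkedPrefill,
    ) and speculative_config is not None:
        return "TND"
    return "BNSD"


# Pure spec decoding keeps using TND.
assert select_input_layout(AscendAttentionState.SpecDecoding, object()) == "TND"
# ChunkedPrefill only switches to TND when a speculative config is present.
assert select_input_layout(AscendAttentionState.ChunkedPrefill, object()) == "TND"
assert select_input_layout(AscendAttentionState.ChunkedPrefill, None) == "BNSD"
assert select_input_layout(AscendAttentionState.DecodeOnly, object()) == "BNSD"
```

With speculative decoding disabled (`speculative_config is None`), ChunkedPrefill stays on the BNSD layout, so the patch is a no-op for non-speculative runs.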