[Bugfix] Fix specdecoding in chunkedprefill scenario (#3025)

### What this PR does / why we need it?

The speculative decode phase of chunkedprefill has taken an incorrect
path, should always use TND layout for speculative decoding.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.10.2
- vLLM main:
6d8246aaff

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-09-19 14:05:08 +08:00
committed by GitHub
parent 833cd1b698
commit 2a87b4cecb
2 changed files with 12 additions and 12 deletions

View File

@@ -495,11 +495,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.ring_mla_mask_size = 512
self.prefill_mask = None
# Adapt torch air graph mode with spec decoding.
speculative_config = vllm_config.speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.speculative_config = vllm_config.speculative_config
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
@@ -811,7 +807,11 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qk_rope_head_dim)
input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill
] and self.speculative_config is not None:
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)

View File

@@ -676,11 +676,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
self.prefill_mask = None
self.ring_mla_mask_size = 512
# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.speculative_config = get_current_vllm_config().speculative_config
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
# Convert from (B, N, L) to (N, B, L)
@@ -1012,7 +1008,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
self.qk_rope_head_dim)
input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill
] and self.speculative_config is not None:
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)