[Bugfix] Fix spec decoding in chunked prefill scenario (#3025)
### What this PR does / why we need it?
The speculative decoding phase of chunked prefill was taking an incorrect
attention path: when speculative decoding is enabled, the query should
always use the TND layout, including for speculative tokens that arrive
inside a chunked prefill batch.
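For context, a minimal sketch of the corrected layout selection. The stub enum and the `select_input_layout` helper are illustrative stand-ins, not the vllm-ascend code; the real `AscendAttentionState` enum has more members:

```python
# Sketch only: mirrors the condition in the diff below, with a stub enum.
from enum import Enum, auto


class AscendAttentionState(Enum):
    # Illustrative subset; the real enum in vllm-ascend has more states.
    ChunkedPrefill = auto()
    DecodeOnly = auto()
    SpecDecoding = auto()


def select_input_layout(attn_state, speculative_config):
    # Before this fix, only pure SpecDecoding switched to TND, so a
    # chunked prefill that carried speculative tokens stayed on BNSD
    # and took the wrong attention path.
    if attn_state in (
            AscendAttentionState.SpecDecoding,
            AscendAttentionState.ChunkedPrefill,
    ) and speculative_config is not None:
        return "TND"
    return "BNSD"


assert select_input_layout(AscendAttentionState.ChunkedPrefill, object()) == "TND"
assert select_input_layout(AscendAttentionState.ChunkedPrefill, None) == "BNSD"
assert select_input_layout(AscendAttentionState.DecodeOnly, object()) == "BNSD"
```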
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main: 6d8246aaff
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
```diff
@@ -495,11 +495,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ring_mla_mask_size = 512
         self.prefill_mask = None
 
-        # Adapt torch air graph mode with spec decoding.
-        speculative_config = vllm_config.speculative_config
-        if speculative_config is not None:
-            self.spec_token_num = speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
+        self.speculative_config = vllm_config.speculative_config
 
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
```
```diff
@@ -811,7 +807,11 @@ class AscendMLAImpl(MLAAttentionImpl):
                                   self.qk_rope_head_dim)
         input_layout = "BNSD"
 
-        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ] and self.speculative_config is not None:
+            # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
```
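As a shape illustration of why the TND branch reshapes `q_nope`: BNSD keeps separate batch and sequence axes, while TND flattens them into a single token axis, which is what `q_nope.view(num_tokens, ...)` above produces. A sketch with arbitrary sizes, not the actual kernel code:

```python
# Illustrative only: arbitrary sizes showing the BNSD -> TND reshape
# implied by q_nope.view(num_tokens, num_heads, -1) above.
import torch

bs, q_seq_len, num_heads, dim = 4, 2, 8, 128  # e.g. 1 target + 1 draft token
q_bnsd = torch.randn(bs, num_heads, q_seq_len, dim)  # (B, N, S, D)

num_tokens = bs * q_seq_len
q_tnd = q_bnsd.transpose(1, 2).reshape(num_tokens, num_heads, dim)  # (T, N, D)
assert q_tnd.shape == (num_tokens, num_heads, dim)
```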
```diff
@@ -676,11 +676,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         self.prefill_mask = None
         self.ring_mla_mask_size = 512
 
-        # Adapt torch air graph mode with spec decoding.
-        speculative_config = get_current_vllm_config().speculative_config
-        if speculative_config is not None:
-            self.spec_token_num = speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
+        self.speculative_config = get_current_vllm_config().speculative_config
 
     def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
```
```diff
@@ -1012,7 +1008,11 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                                   self.qk_rope_head_dim)
         input_layout = "BNSD"
 
-        if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+        if attn_metadata.attn_state in [
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.ChunkedPrefill
+        ] and self.speculative_config is not None:
+            # Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
             input_layout = "TND"
             # [bs * q_seq_len, num_heads_per_rank, dim]
             q_nope = q_nope.view(num_tokens, self.num_heads, -1)
```