[Feat]Make full graph mode compalible with MTP (#3276)

### What this PR does / why we need it?
Make the Full Graph mode can run with MTP.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-17 20:19:56 +08:00
committed by GitHub
parent 46e62efd44
commit 248ee7fa11
7 changed files with 103 additions and 44 deletions

View File

@@ -175,7 +175,7 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
AttentionCGSupport.UNIFORM_BATCH
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -209,7 +209,6 @@ class AscendMLAMetadataBuilder:
got {self.decode_threshold}"
self.reorder_batch_threshold = self.decode_threshold
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
@@ -427,10 +426,10 @@ class AscendMLAMetadataBuilder:
sin=sin,
cos=cos)
else:
cos[:num_decodes,
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
1).unsqueeze(2)
sin[:num_decodes,
sin[:num_decode_tokens,
...] = self.sin_cache[input_positions].unsqueeze(
1).unsqueeze(2)
@@ -442,8 +441,8 @@ class AscendMLAMetadataBuilder:
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decodes, ...],
cos=cos[:num_decodes, ...])
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...])
return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -469,7 +468,10 @@ class AscendMLAMetadataBuilder:
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
model: Optional[nn.Module] = None,
):
if attn_state == AscendAttentionState.DecodeOnly:
if attn_state in {
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
}:
attn_metadata = self.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -477,7 +479,7 @@ class AscendMLAMetadataBuilder:
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
)
attn_metadata.attn_state = attn_state
@@ -955,7 +957,8 @@ class AscendMLAImpl(MLAAttentionImpl):
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.DecodeOnly,
] and self.speculative_config is not None:
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
input_layout = "TND"