[Feat]Make full graph mode compalible with MTP (#3276)
### What this PR does / why we need it? Make the Full Graph mode can run with MTP. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -175,7 +175,7 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
class AscendMLAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -209,7 +209,6 @@ class AscendMLAMetadataBuilder:
|
||||
got {self.decode_threshold}"
|
||||
|
||||
self.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
@@ -427,10 +426,10 @@ class AscendMLAMetadataBuilder:
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
else:
|
||||
cos[:num_decodes,
|
||||
cos[:num_decode_tokens,
|
||||
...] = self.cos_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
sin[:num_decodes,
|
||||
sin[:num_decode_tokens,
|
||||
...] = self.sin_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
|
||||
@@ -442,8 +441,8 @@ class AscendMLAMetadataBuilder:
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:num_decodes, ...],
|
||||
cos=cos[:num_decodes, ...])
|
||||
sin=sin[:num_decode_tokens, ...],
|
||||
cos=cos[:num_decode_tokens, ...])
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -469,7 +468,10 @@ class AscendMLAMetadataBuilder:
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
if attn_state in {
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
}:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
@@ -477,7 +479,7 @@ class AscendMLAMetadataBuilder:
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
"Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state"
|
||||
)
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
@@ -955,7 +957,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.SpecDecoding,
|
||||
AscendAttentionState.ChunkedPrefill
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.DecodeOnly,
|
||||
] and self.speculative_config is not None:
|
||||
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
|
||||
input_layout = "TND"
|
||||
|
||||
Reference in New Issue
Block a user