[Feat][Graph]Support FULL_DECEDE_ONLY mode for MLA models (#3125)
### What this PR does / why we need it?
Adds support for capturing the Multi-Layer Attention (MLA) decode
operation into an ACL graph. This improves performance by compiling the
attention kernel for single-token decoding.
Key changes include:
- Implementing the graph capture logic for the MLA kernel, including
workspace management and parameter updates.
- Modifying the rotary embedding (RoPE) handling to use pre-allocated
tensors, which is a requirement for graph capture.
- Adding a `build_for_graph_capture` method to the MLA metadata builder
to create dummy metadata during the graph compilation phase.
Known issues:
- Currently, MTP is not supported in FULL_DECEDE_ONLY mode -- we're
working on a fix
- We are preparing to remove update_mla_attn_params with
auto_dispatch_capture
### Does this PR introduce _any_ user-facing change?
compilation_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
### How was this patch tested?
- vLLM version: v0.11.0
---------
Signed-off-by: panchao-hub <315134829@qq.com>
Signed-off-by: p00465316 <panchao13@huawei.com>
Co-authored-by: p00465316 <panchao13@huawei.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
MLAAttentionImpl)
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
@@ -21,6 +22,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
@@ -169,7 +171,7 @@ M = TypeVar("M", bound=AscendMLAMetadata)
|
||||
class AscendMLAMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@@ -389,6 +391,8 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
cos = common_attn_metadata.cos
|
||||
sin = common_attn_metadata.sin
|
||||
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
|
||||
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
@@ -397,21 +401,45 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = block_table[:num_decodes, ...]
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
# TODO: After the fullgraph supports MTP, the if branch needs to deleted
|
||||
assert self.cos_cache is not None
|
||||
assert self.sin_cache is not None
|
||||
if cos is None and sin is None:
|
||||
cos = self.cos_cache[
|
||||
input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
sin = self.sin_cache[
|
||||
input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
else:
|
||||
cos[:num_decodes,
|
||||
...] = self.cos_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
sin[:num_decodes,
|
||||
...] = self.sin_cache[input_positions].unsqueeze(
|
||||
1).unsqueeze(2)
|
||||
|
||||
decode_metadata = AscendMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin[:num_decodes, ...],
|
||||
cos=cos[:num_decodes, ...])
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -431,6 +459,26 @@ class AscendMLAMetadataBuilder:
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
def build_for_graph_capture(
|
||||
self,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
|
||||
model: Optional[nn.Module] = None,
|
||||
):
|
||||
if attn_state == AscendAttentionState.DecodeOnly:
|
||||
attn_metadata = self.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
model=model,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently we only support building dummy metadata for DecodeOnly state"
|
||||
)
|
||||
|
||||
attn_metadata.attn_state = attn_state
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class DecodeMLAPreprocessResult(NamedTuple):
|
||||
ql_nope: Optional[torch.Tensor] = None
|
||||
@@ -834,24 +882,63 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
block_table=decode_meta.block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=decode_meta.seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths)
|
||||
common_kwargs = {
|
||||
'query_rope': q_pe,
|
||||
'key_rope': k_pe,
|
||||
'num_heads': self.num_heads,
|
||||
'num_key_value_heads': self.num_kv_heads,
|
||||
'input_layout': input_layout,
|
||||
'atten_mask': spec_attn_mask,
|
||||
'sparse_mode': sparse_mode,
|
||||
'scale': self.scale,
|
||||
'antiquant_mode': 0,
|
||||
'antiquant_scale': None,
|
||||
'block_table': decode_meta.block_table,
|
||||
'block_size': block_size,
|
||||
"actual_seq_lengths": actual_seq_lengths,
|
||||
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
|
||||
}
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if forward_context.capturing:
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
attn_output = torch.empty_like(q_nope)
|
||||
softmax_lse = torch.empty(num_tokens,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(q_nope, k_nope, q_pe, k_pe, self.num_heads, self.num_kv_heads,
|
||||
input_layout, spec_attn_mask, sparse_mode, self.scale,
|
||||
decode_meta.block_table, block_size,
|
||||
decode_meta.seq_lens_list, actual_seq_lengths, workspace,
|
||||
attn_output, softmax_lse))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
**common_kwargs,
|
||||
workspace=workspace,
|
||||
out=[attn_output, softmax_lse])
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(
|
||||
q_nope, k_nope, k_nope, **common_kwargs)
|
||||
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
|
||||
Reference in New Issue
Block a user