support cp&dcp (#3260)
### What this PR does / why we need it? This PR adds the Prefill Context Parallelism (PCP) feature, which corresponds to DCP. For specific implementation details, please refer to the RFC https://github.com/vllm-project/vllm/issues/25749. TL;DR: PCP enhances long-sequence inference capabilities by partitioning the sequence dimension during the prefill stage. ### Does this PR introduce _any_ user-facing change? The current implementation primarily includes the following changes: Modified ModelRunner.py for CP partitioning logic for tokens; Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backend to PCP. Modified block_tables.py to extend the KV cache storage based on DCP&PCP; Added necessary command-line arguments to control parallelism for PCP; ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: chenjie <chenjie137@huawei.com> Signed-off-by: Delphine-Nic <tanwenqin@huawei.com> Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com> Signed-off-by: Feng Liu <liufeng248@huawei.com> Signed-off-by: gaojc <1055866782@qq.com> Signed-off-by: weiguihua2 <weiguihua2@huawei.com> Signed-off-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: chenjie <chenjie137@huawei.com> Co-authored-by: Delphine-Nic <tanwenqin@huawei.com> Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com> Co-authored-by: Feng Liu <liufeng248@huawei.com> Co-authored-by: gaojc <1055866782@qq.com> Co-authored-by: weiguihua2 <weiguihua2@huawei.com> Co-authored-by: z50049692 <zhangmingwei11@huawei.com> Co-authored-by: w00896881 <wangzixuan40@huawei.com>
This commit is contained in:
@@ -19,29 +19,45 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import ClassVar, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer, AttentionType)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_dcp_group,
|
||||
get_decode_context_model_parallel_rank,
|
||||
get_decode_context_model_parallel_world_size)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec, version_check)
|
||||
nd_to_nz_2d, nd_to_nz_spec,
|
||||
prefill_context_parallel_enable, version_check)
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
from vllm.distributed import (get_pcp_group,
|
||||
get_prefill_context_model_parallel_rank,
|
||||
get_prefill_context_model_parallel_world_size
|
||||
)
|
||||
# isort:on
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
@@ -127,15 +143,47 @@ class AscendAttentionState(Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
class AscendPCPMetadata:
|
||||
q_head_idx: torch.Tensor = None
|
||||
q_tail_idx: torch.Tensor = None
|
||||
kv_with_q_head_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_head_mask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_nomask_idx: torch.Tensor = None
|
||||
kv_with_q_tail_mask_idx: torch.Tensor = None
|
||||
attn_mask_seqlens: torch.Tensor = None
|
||||
head_attn_nomask_seqlens: torch.Tensor = None
|
||||
tail_attn_nomask_seqlens: torch.Tensor = None
|
||||
q_full_idx: torch.Tensor = None
|
||||
pcp_prefill_mask: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
# Number of tokens excluding padding.
|
||||
num_actual_tokens_pcp_padded: int = 0
|
||||
num_actual_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
# The sequence length per sequence. Sequence length means the computed
|
||||
# tokens + new tokens (is None if it is a decoding).
|
||||
@@ -168,6 +216,10 @@ class AscendMetadata:
|
||||
# *************************** Other Properties *************************** #
|
||||
enable_dbo_across_dp: bool = False
|
||||
|
||||
prefill: Optional[AscendMetadataForPrefill] = None
|
||||
|
||||
decode_meta: Optional[AscendMetadataForDecode] = None
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
@@ -207,10 +259,25 @@ class AscendAttentionMetadataBuilder:
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
|
||||
decode_threshold = 1
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
|
||||
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
||||
if num_actual_tokens_pcp_padded is None:
|
||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
@@ -218,7 +285,7 @@ class AscendAttentionMetadataBuilder:
|
||||
+ 1]
|
||||
|
||||
if attn_state == AscendAttentionState.DecodeOnly and \
|
||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
|
||||
seq_lens = torch.cat([
|
||||
seq_lens,
|
||||
@@ -252,8 +319,51 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
pcp_metadata = None
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
if common_long_seq_metadata is not None:
|
||||
pcp_metadata = AscendPCPMetadata(
|
||||
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
||||
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
||||
kv_with_q_head_nomask_idx=common_long_seq_metadata.
|
||||
kv_with_q_head_nomask_idx_tensor,
|
||||
kv_with_q_head_mask_idx=common_long_seq_metadata.
|
||||
kv_with_q_head_mask_idx_tensor,
|
||||
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
|
||||
kv_with_q_tail_nomask_idx_tensor,
|
||||
kv_with_q_tail_mask_idx=common_long_seq_metadata.
|
||||
kv_with_q_tail_mask_idx_tensor,
|
||||
attn_mask_seqlens=common_long_seq_metadata.
|
||||
attn_mask_seqlens,
|
||||
head_attn_nomask_seqlens=common_long_seq_metadata.
|
||||
head_attn_nomask_seqlens,
|
||||
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||
tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
|
||||
prefill_metadata = AscendMetadataForPrefill(
|
||||
pcp_metadata=pcp_metadata,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||
pcp_allgather_restore_idx
|
||||
if common_long_seq_metadata is not None else None)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
if common_long_seq_metadata is not None:
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
num_computed_tokens_of_pcp_dcp = np.array(
|
||||
num_computed_tokens_of_pcp_dcp)
|
||||
decode_metadata = AscendMetadataForDecode(
|
||||
num_computed_tokens_of_pcp_dcp=
|
||||
num_computed_tokens_of_pcp_dcp)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
@@ -264,7 +374,11 @@ class AscendAttentionMetadataBuilder:
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
prefill=prefill_metadata,
|
||||
decode_meta=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_graph_capture(
|
||||
@@ -322,6 +436,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
self.torch_npu_check = version_check()
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
) if self.pcp_size > 1 else 0
|
||||
self.pcp_group = get_pcp_group(
|
||||
).device_group if self.pcp_size > 1 else None
|
||||
|
||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
self.dcp_group = get_dcp_group(
|
||||
).device_group if self.dcp_size > 1 else None
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
@@ -581,6 +707,236 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
|
||||
lengths: List[int]) -> torch.Tensor:
|
||||
max_len = max(lengths)
|
||||
splits = torch.split(tensor_tnd, lengths, dim=0)
|
||||
|
||||
padded = []
|
||||
for s in splits:
|
||||
pad_len = max_len - s.shape[0]
|
||||
s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
|
||||
padded.append(s_pad)
|
||||
|
||||
tensor_bsnd = torch.stack(padded, dim=0)
|
||||
return tensor_bsnd
|
||||
|
||||
def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
|
||||
lengths: List[int]) -> torch.Tensor:
|
||||
slices = []
|
||||
for i, length in enumerate(lengths):
|
||||
slices.append(tensor_bsnd[i, :length])
|
||||
tensor_tnd = torch.cat(slices, dim=0)
|
||||
return tensor_tnd
|
||||
|
||||
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
||||
q_seqlens: List[int],
|
||||
k_nomask: torch.Tensor,
|
||||
v_nomask: torch.Tensor,
|
||||
kv_seqlens_nomask: List[int],
|
||||
k_mask: torch.Tensor,
|
||||
v_mask: torch.Tensor,
|
||||
kv_seqlens_mask: List[int],
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
q = self._pack_tnd_2_bsnd(q, q_seqlens)
|
||||
|
||||
# nomask Attention
|
||||
if k_nomask is not None:
|
||||
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q,
|
||||
self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
|
||||
self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=kv_seqlens_nomask,
|
||||
actual_seq_lengths=q_seqlens)
|
||||
attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
|
||||
q_seqlens)
|
||||
# (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
|
||||
attn_lse_nomask = self._unpack_bsnd_2_tnd(
|
||||
attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
|
||||
|
||||
# mask Attention
|
||||
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
q,
|
||||
self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
|
||||
self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSND",
|
||||
atten_mask=mask,
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=kv_seqlens_mask,
|
||||
actual_seq_lengths=q_seqlens)
|
||||
attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
|
||||
attn_lse_mask = self._unpack_bsnd_2_tnd(
|
||||
attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
|
||||
|
||||
# update
|
||||
output = attn_out_mask
|
||||
if k_nomask is not None:
|
||||
output, _ = self._update_out_and_lse(
|
||||
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
|
||||
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
|
||||
|
||||
return output
|
||||
|
||||
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.prefill is not None
|
||||
assert attn_metadata.prefill.pcp_metadata is not None
|
||||
# Use precomputed indices from the metadata (already converted to tensors and on device)
|
||||
q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
|
||||
q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
|
||||
kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
|
||||
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
|
||||
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
|
||||
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
|
||||
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
|
||||
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
|
||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
|
||||
# 1. Attention calculation in the first half of Q in load balancing
|
||||
output_head = self._attention_with_nomask_and_mask(
|
||||
q=torch.index_select(query, 0, q_head_idx),
|
||||
q_seqlens=attn_mask_seqlens[0].tolist(),
|
||||
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
|
||||
if self.pcp_rank > 0 else None,
|
||||
v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
|
||||
if self.pcp_rank > 0 else None,
|
||||
kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
|
||||
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
|
||||
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
|
||||
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
|
||||
mask=mask)
|
||||
|
||||
# 2. the Attention calculation in the latter half of Q in load balancing
|
||||
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
|
||||
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
|
||||
output_tail = self._attention_with_nomask_and_mask(
|
||||
q=torch.index_select(query, 0, q_tail_idx),
|
||||
q_seqlens=attn_mask_seqlens[0].tolist(),
|
||||
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
|
||||
v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
|
||||
kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
|
||||
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
|
||||
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
|
||||
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
|
||||
mask=mask)
|
||||
|
||||
# 3. Combine the output of the first half and second half.
|
||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||
output = torch.index_select(
|
||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
||||
return output
|
||||
|
||||
def _update_out_and_lse(self, out_list: torch.Tensor,
|
||||
lse_list: torch.Tensor) -> torch.Tensor:
|
||||
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
|
||||
Args:
|
||||
out_list: shape = [N, batch_size, num_heads, head_size]
|
||||
lse_list: shape = [N, batch_size, num_heads, 1]
|
||||
Returns:
|
||||
out_final: shape = [batch_size, num_heads, head_size]
|
||||
lse_final: shape = [batch_size, num_heads, 1]
|
||||
"""
|
||||
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
|
||||
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
|
||||
dim=0)
|
||||
return out_final, lse_final
|
||||
|
||||
def _forward_decode_pcp_dcp(self, query: torch.Tensor,
|
||||
attn_metadata: AscendMetadata) -> torch.Tensor:
|
||||
assert self.key_cache is not None
|
||||
assert self.value_cache is not None
|
||||
|
||||
if self.dcp_size > 1:
|
||||
query = get_dcp_group().all_gather(query, 1)
|
||||
num_heads = self.num_heads * self.dcp_size
|
||||
else:
|
||||
num_heads = self.num_heads
|
||||
|
||||
# 1. Compute out&lse by "npu_fused_infer_attention_score"
|
||||
attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
|
||||
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
|
||||
self.key_cache.view(self.key_cache.shape[0],
|
||||
self.key_cache.shape[1], -1),
|
||||
self.value_cache.view(self.key_cache.shape[0],
|
||||
self.key_cache.shape[1], -1),
|
||||
num_heads=num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BSND",
|
||||
atten_mask=None,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
block_table=attn_metadata.block_tables,
|
||||
block_size=self.key_cache.shape[1],
|
||||
actual_seq_lengths_kv=attn_metadata.decode_meta.
|
||||
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
|
||||
)
|
||||
|
||||
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
|
||||
attn_out.shape[3])
|
||||
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
|
||||
if self.dcp_size > 1:
|
||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
|
||||
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
|
||||
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
|
||||
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
|
||||
dist.all_to_all_single(attn_out_lse_all2all,
|
||||
attn_out_lse,
|
||||
group=self.dcp_group)
|
||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
||||
attn_out_lse_split_on_seq = list(
|
||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
||||
|
||||
attn_out_lse_split_dcp = torch.stack(
|
||||
attn_out_lse_split_on_seq,
|
||||
dim=0) # [dcp, batch_size, num_heads, head_size+1]
|
||||
# Update out&lse
|
||||
attn_out_split_dcp, attn_lse_split_dcp = torch.split(
|
||||
attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
|
||||
attn_out, attn_lse = self._update_out_and_lse(
|
||||
attn_out_split_dcp, attn_lse_split_dcp)
|
||||
if self.pcp_size > 1:
|
||||
# 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
|
||||
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
|
||||
# 3. AllGather out&lse within CP group
|
||||
attn_out_lse_list = [
|
||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(attn_out_lse_list,
|
||||
attn_out_lse,
|
||||
group=self.pcp_group)
|
||||
# 4. Update out&lse
|
||||
attn_out_lse_allgather = torch.stack(
|
||||
attn_out_lse_list,
|
||||
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
||||
attn_out_allgather, attn_lse_allgather = torch.split(
|
||||
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
||||
attn_out, _ = self._update_out_and_lse(attn_out_allgather,
|
||||
attn_lse_allgather)
|
||||
return attn_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
@@ -633,7 +989,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
else:
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
|
||||
@@ -650,14 +1009,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
if len(kv_cache) > 1:
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
if attn_type == AttentionType.ENCODER_ONLY:
|
||||
|
||||
if has_decode:
|
||||
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
|
||||
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:num_decode_tokens],
|
||||
value=value[:num_decode_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slot_mapping)
|
||||
|
||||
if has_prefill:
|
||||
if self.pcp_size > 1:
|
||||
kv = torch.cat([key, value], dim=-1)
|
||||
all_kv = get_pcp_group().all_gather(kv, dim=0)
|
||||
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
|
||||
all_kv = torch.index_select(all_kv, 0,
|
||||
pcp_allgather_restore_idx)
|
||||
key, value = all_kv.split(
|
||||
[self.head_size, self.head_size], dim=-1)
|
||||
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[self.pcp_size *
|
||||
num_decode_tokens:attn_metadata.
|
||||
num_actual_tokens_pcp_padded],
|
||||
value=value[self.pcp_size *
|
||||
num_decode_tokens:attn_metadata.
|
||||
num_actual_tokens_pcp_padded],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=attn_metadata.
|
||||
slot_mapping[self.pcp_size *
|
||||
num_decode_tokens:attn_metadata.
|
||||
num_actual_tokens_pcp_padded])
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
output = self._forward_pcp_dcp(query, key, value,
|
||||
attn_metadata, output)
|
||||
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
attn_out = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
@@ -668,7 +1059,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale=self.scale,
|
||||
sparse_mode=4,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=attn_metadata.max_query_len,
|
||||
next_tockens=attn_metadata.max_query_len,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
@@ -679,7 +1069,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
@@ -701,6 +1091,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
if output is None:
|
||||
raise ValueError("Output buffer is required")
|
||||
if has_decode:
|
||||
decode_query = query[:num_decode_tokens]
|
||||
output_decode = self._forward_decode_pcp_dcp(
|
||||
decode_query, attn_metadata)
|
||||
output[:num_decode_tokens] = output_decode
|
||||
if has_prefill:
|
||||
prefill_query = query[num_decode_tokens:]
|
||||
key = key[self.pcp_size * num_decode_tokens:]
|
||||
value = value[self.pcp_size * num_decode_tokens:]
|
||||
if self.pcp_size > 1:
|
||||
output_prefill = self._forward_prefill_cp(
|
||||
prefill_query, key, value, attn_metadata)
|
||||
else:
|
||||
max_prefill_seq_len = attn_metadata.seq_lens[
|
||||
attn_metadata.num_decode_tokens:].max().item()
|
||||
if attn_metadata.attn_mask is not None:
|
||||
attn_metadata.attn_mask = attn_metadata.attn_mask[:
|
||||
max_prefill_seq_len, :
|
||||
max_prefill_seq_len]
|
||||
else:
|
||||
ValueError("Attn_metadata.attn_mask is required")
|
||||
seq_lens_back = attn_metadata.seq_lens
|
||||
attn_metadata.seq_lens = attn_metadata.seq_lens[
|
||||
attn_metadata.num_decode_tokens:]
|
||||
output_prefill = self._forward_prefill_no_cache(
|
||||
prefill_query, key, value, attn_metadata,
|
||||
output[num_decode_tokens:], prefill_query.shape[0])
|
||||
attn_metadata.seq_lens = seq_lens_back
|
||||
output[num_decode_tokens:] = output_prefill
|
||||
return output
|
||||
|
||||
|
||||
def unified_ascend_attention_with_output(
|
||||
query: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user