support cp&dcp (#3260)

### What this PR does / why we need it?
This PR adds the Prefill Context Parallelism (PCP) feature, the prefill-side
counterpart of Decode Context Parallelism (DCP). For implementation details,
please refer to the RFC: https://github.com/vllm-project/vllm/issues/25749.
TL;DR: PCP improves long-sequence inference by partitioning the sequence
dimension across ranks during the prefill stage.
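
For intuition, the sketch below (illustration only, not code from this PR) shows the basic idea: each PCP rank prefills only its slice of the prompt tokens, and full-context attention is reconstructed by exchanging KV and merging partial outputs across the PCP group. The actual implementation uses a more elaborate load-balanced head/tail split, shown in the attention backend changes below.

```python
# Minimal sketch of sequence-dimension partitioning for prefill, assuming
# pcp_size ranks and a toy 16-token prompt; names are illustrative only.
import torch

pcp_size = 2
prompt = torch.arange(16)                 # 16 prompt token ids
per_rank = torch.chunk(prompt, pcp_size)  # each rank prefills its own slice
print([t.tolist() for t in per_rank])
# [[0, ..., 7], [8, ..., 15]]; attention over the full context is then
# recovered by exchanging KV across ranks and merging the partial outputs.
```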
### Does this PR introduce _any_ user-facing change?
The current implementation primarily includes the following changes:

- Modified ModelRunner.py to add the CP token-partitioning logic.
- Modified attention_v1.py and mla_v1.py to adapt the GQA/MLA backends to PCP.
- Modified block_tables.py to extend the KV cache storage for DCP & PCP.
- Added the command-line arguments needed to control PCP parallelism.
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: chenjie <chenjie137@huawei.com>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: Feng Liu <liufeng248@huawei.com>
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Signed-off-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: chenjie <chenjie137@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Co-authored-by: Feng Liu <liufeng248@huawei.com>
Co-authored-by: gaojc <1055866782@qq.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: z50049692 <zhangmingwei11@huawei.com>
Co-authored-by: w00896881 <wangzixuan40@huawei.com>
Commit b54d44e664 (parent 2bcadcb9d5), authored by LookAround0301, committed by GitHub on 2025-10-24 10:32:01 +08:00.
18 changed files with 1729 additions and 211 deletions.


@@ -19,29 +19,45 @@ from dataclasses import dataclass
from enum import Enum
from typing import ClassVar, List, Optional, Tuple, Type
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.config import VllmConfig
from vllm.distributed import (get_dcp_group,
get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
# isort: off
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec,
prefill_context_parallel_enable, version_check)
from ..utils import weak_ref_tensors
if prefill_context_parallel_enable():
from vllm.distributed import (get_pcp_group,
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort:on
class AscendAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -127,15 +143,47 @@ class AscendAttentionState(Enum):
@dataclass
class AscendPCPMetadata:
q_head_idx: torch.Tensor = None
q_tail_idx: torch.Tensor = None
kv_with_q_head_nomask_idx: torch.Tensor = None
kv_with_q_head_mask_idx: torch.Tensor = None
kv_with_q_tail_nomask_idx: torch.Tensor = None
kv_with_q_tail_mask_idx: torch.Tensor = None
attn_mask_seqlens: torch.Tensor = None
head_attn_nomask_seqlens: torch.Tensor = None
tail_attn_nomask_seqlens: torch.Tensor = None
q_full_idx: torch.Tensor = None
pcp_prefill_mask: torch.Tensor = None
@dataclass
class AscendMetadataForPrefill:
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
pcp_allgather_restore_idx: Optional[List[int]] = None
@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
list[int]]]]]] = None
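A hedged illustration of the layout this field is assumed to have (the values below are made up; in the builder it is converted to a NumPy array and indexed as [request, pcp_rank, dcp_rank] in the decode path):

```python
# Illustrative values only: one entry per decode request, then per PCP rank,
# then per DCP rank, giving how many computed KV tokens that rank holds.
num_computed_tokens_of_pcp_dcp = [
    [[10, 9], [9, 9]],   # request 0: 37 cached tokens spread over 2x2 ranks
    [[2, 2], [2, 2]],    # request 1: 8 cached tokens
]
# The builder wraps this in np.array(...), giving shape
# (num_decodes, pcp_size, dcp_size), so each rank can slice
# arr[:, pcp_rank, dcp_rank] to get its per-request KV lengths.
```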
@dataclass
class AscendMetadata:
# **************************** Basic Properties ************************** #
attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
# Number of actual tokens after PCP padding of the sequence dimension
# (falls back to num_actual_tokens when PCP is disabled).
num_actual_tokens_pcp_padded: int = 0
# Number of tokens excluding padding.
num_actual_tokens: int = 0
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
# The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding).
@@ -168,6 +216,10 @@ class AscendMetadata:
# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False
prefill: Optional[AscendMetadataForPrefill] = None
decode_meta: Optional[AscendMetadataForDecode] = None
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
@@ -207,10 +259,25 @@ class AscendAttentionMetadataBuilder:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
decode_threshold = 1
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens_pcp_padded]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -218,7 +285,7 @@ class AscendAttentionMetadataBuilder:
+ 1]
if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
seq_lens = torch.cat([
seq_lens,
@@ -252,8 +319,51 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
prefill_metadata = None
if num_prefills > 0:
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
kv_with_q_head_nomask_idx=common_long_seq_metadata.
kv_with_q_head_nomask_idx_tensor,
kv_with_q_head_mask_idx=common_long_seq_metadata.
kv_with_q_head_mask_idx_tensor,
kv_with_q_tail_nomask_idx=common_long_seq_metadata.
kv_with_q_tail_nomask_idx_tensor,
kv_with_q_tail_mask_idx=common_long_seq_metadata.
kv_with_q_tail_mask_idx_tensor,
attn_mask_seqlens=common_long_seq_metadata.
attn_mask_seqlens,
head_attn_nomask_seqlens=common_long_seq_metadata.
head_attn_nomask_seqlens,
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx
if common_long_seq_metadata is not None else None)
decode_metadata = None
if num_decodes > 0:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
num_computed_tokens_of_pcp_dcp = np.array(
num_computed_tokens_of_pcp_dcp)
decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=
num_computed_tokens_of_pcp_dcp)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
@@ -264,7 +374,11 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
num_prefills=num_prefills,
num_decodes=num_decodes,
prefill=prefill_metadata,
decode_meta=decode_metadata)
return attn_metadata
def build_for_graph_capture(
@@ -322,6 +436,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache = None
self.value_cache = None
self.torch_npu_check = version_check()
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.pcp_group = get_pcp_group(
).device_group if self.pcp_size > 1 else None
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group(
).device_group if self.dcp_size > 1 else None
def _forward_prefill_no_cache(
self,
@@ -581,6 +707,236 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
return output
def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
max_len = max(lengths)
splits = torch.split(tensor_tnd, lengths, dim=0)
padded = []
for s in splits:
pad_len = max_len - s.shape[0]
s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
padded.append(s_pad)
tensor_bsnd = torch.stack(padded, dim=0)
return tensor_bsnd
def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
lengths: List[int]) -> torch.Tensor:
slices = []
for i, length in enumerate(lengths):
slices.append(tensor_bsnd[i, :length])
tensor_tnd = torch.cat(slices, dim=0)
return tensor_tnd
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
q_seqlens: List[int],
k_nomask: torch.Tensor,
v_nomask: torch.Tensor,
kv_seqlens_nomask: List[int],
k_mask: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: List[int],
mask: torch.Tensor) -> torch.Tensor:
q = self._pack_tnd_2_bsnd(q, q_seqlens)
# nomask Attention
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_nomask,
actual_seq_lengths=q_seqlens)
attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
q_seqlens)
# (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
attn_lse_nomask = self._unpack_bsnd_2_tnd(
attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
# mask Attention
attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
q,
self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=mask,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_mask,
actual_seq_lengths=q_seqlens)
attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
attn_lse_mask = self._unpack_bsnd_2_tnd(
attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
# update
output = attn_out_mask
if k_nomask is not None:
output, _ = self._update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
return output
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
assert attn_metadata is not None
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.pcp_metadata is not None
# Use precomputed indices from the metadata (already converted to tensors and on device)
q_head_idx = attn_metadata.prefill.pcp_metadata.q_head_idx
q_tail_idx = attn_metadata.prefill.pcp_metadata.q_tail_idx
kv_with_q_head_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_nomask_idx
kv_with_q_head_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_head_mask_idx
kv_with_q_tail_nomask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_nomask_idx
kv_with_q_tail_mask_idx = attn_metadata.prefill.pcp_metadata.kv_with_q_tail_mask_idx
attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
# 1. Attention for the head (first) half of Q under the load-balanced split.
output_head = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_head_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
if self.pcp_rank > 0 else None,
kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
mask=mask)
# 2. Attention for the tail (latter) half of Q under the load-balanced split, e.g.:
# pcp_rank0: Q3*KV0~KV2 (no mask) + Q3*KV3 (masked)
# pcp_rank1: Q2*KV0~KV1 (no mask) + Q2*KV2 (masked)
output_tail = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_tail_idx),
q_seqlens=attn_mask_seqlens[0].tolist(),
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
mask=mask)
# 3. Combine the outputs of the head and tail halves.
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
return output
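The head/tail split above implements a load-balanced causal partition. The following sketch is an assumption inferred from the rank comments above (the helper name and shape are illustrative, not code from the PR): each sequence is cut into 2 * pcp_size chunks and rank r owns one small-span chunk and one large-span chunk so that causal work is roughly equal per rank.

```python
# Chunk-to-rank assignment implied by the load-balanced split (illustrative).
def pcp_chunk_assignment(pcp_size: int) -> dict[int, tuple[int, int]]:
    num_chunks = 2 * pcp_size
    # rank r owns chunk r ("head") and chunk 2*pcp_size-1-r ("tail")
    return {r: (r, num_chunks - 1 - r) for r in range(pcp_size)}

print(pcp_chunk_assignment(2))
# {0: (0, 3), 1: (1, 2)}
# rank 0: Q0 attends KV0 (masked); Q3 attends KV0~KV2 (no mask) + KV3 (masked)
# rank 1: Q1 attends KV0 (no mask) + KV1 (masked); Q2 attends KV0~KV1 + KV2 (masked)
```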
def _update_out_and_lse(self, out_list: torch.Tensor,
lse_list: torch.Tensor) -> torch.Tensor:
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
Args:
out_list: shape = [N, batch_size, num_heads, head_size]
lse_list: shape = [N, batch_size, num_heads, 1]
Returns:
out_final: shape = [batch_size, num_heads, head_size]
lse_final: shape = [batch_size, num_heads, 1]
"""
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
dim=0)
return out_final, lse_final
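A quick numerical sanity check of the log-sum-exp merge used here (standalone illustration that re-derives the identity rather than calling the method): attention computed separately over two disjoint key chunks, then merged with their LSEs, matches attention over all keys at once.

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 8)        # one query vector, head_size 8
k = torch.randn(16, 8)
v = torch.randn(16, 8)

def partial_attn(q, k, v):
    scores = q @ k.T                                      # (1, S)
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)   # (1, 1)
    out = torch.softmax(scores, dim=-1) @ v               # (1, 8)
    return out, lse

out1, lse1 = partial_attn(q, k[:10], v[:10])   # first key chunk
out2, lse2 = partial_attn(q, k[10:], v[10:])   # second key chunk

# LSE_final = log(sum(exp(LSE_i))); O_final = sum(exp(LSE_i - LSE_final) * O_i)
lse = torch.logsumexp(torch.stack([lse1, lse2]), dim=0)
merged = torch.exp(lse1 - lse) * out1 + torch.exp(lse2 - lse) * out2

full, _ = partial_attn(q, k, v)
assert torch.allclose(merged, full, atol=1e-6)
```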
def _forward_decode_pcp_dcp(self, query: torch.Tensor,
attn_metadata: AscendMetadata) -> torch.Tensor:
assert self.key_cache is not None
assert self.value_cache is not None
if self.dcp_size > 1:
query = get_dcp_group().all_gather(query, 1)
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads
# 1. Compute out&lse by "npu_fused_infer_attention_score"
attn_out, attn_lse = torch.ops.npu.npu_fused_infer_attention_score(
query.view(query.shape[0], 1, query.shape[1], query.shape[2]),
# [b,num_heads,head_size] -> [b,1,num_heads,head_size]
self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
self.value_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1),
num_heads=num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="BSND",
atten_mask=None,
scale=self.scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=attn_metadata.block_tables,
block_size=self.key_cache.shape[1],
actual_seq_lengths_kv=attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
)
attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
attn_out.shape[3])
attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
if self.dcp_size > 1:
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
attn_out_lse_split_on_seq = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
attn_out_lse_split_dcp = torch.stack(
attn_out_lse_split_on_seq,
dim=0) # [dcp, batch_size, num_heads, head_size+1]
# Update out&lse
attn_out_split_dcp, attn_lse_split_dcp = torch.split(
attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
attn_out, attn_lse = self._update_out_and_lse(
attn_out_split_dcp, attn_lse_split_dcp)
if self.pcp_size > 1:
# 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
# 3. AllGather out&lse within CP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
# 4. Update out&lse
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
attn_out, _ = self._update_out_and_lse(attn_out_allgather,
attn_lse_allgather)
return attn_out
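To make the decode-side DCP combine easier to follow, here is a shape-only walk-through of the local reshapes around the all-to-all (collectives omitted; the tensors are random placeholders and the sizes are arbitrary):

```python
import torch

dcp, bs, local_heads, d = 2, 3, 4, 8
# after the head-dim all_gather, each rank holds [bs, local_heads * dcp, d + 1]
attn_out_lse = torch.randn(bs, local_heads * dcp, d + 1)
x = attn_out_lse.permute(1, 2, 0).contiguous()  # [H*dcp, d+1, bs] for all_to_all
# dist.all_to_all_single would swap shards along the leading dim here
x = x.permute(2, 0, 1)                          # back to [bs, H*dcp, d+1]
chunks = torch.chunk(x, dcp, dim=1)             # dcp pieces of [bs, H, d+1]
stacked = torch.stack(chunks, dim=0)            # [dcp, bs, H, d+1]
out, lse = torch.split(stacked, [d, 1], dim=-1) # separate value part from LSE
print(out.shape, lse.shape)  # torch.Size([2, 3, 4, 8]) torch.Size([2, 3, 4, 1])
```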
def forward(
self,
layer: AttentionLayer,
@@ -633,7 +989,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
else:
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
num_actual_tokens = attn_metadata.num_actual_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
@@ -650,14 +1009,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
if len(kv_cache) > 1:
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if attn_type == AttentionType.ENCODER_ONLY:
if has_decode:
slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens * self.pcp_size: self.pcp_size] \
if self.pcp_size * self.dcp_size > 1 else attn_metadata.slot_mapping[:num_decode_tokens]
torch_npu._npu_reshape_and_cache(
key=key[:num_decode_tokens],
value=value[:num_decode_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slot_mapping)
if has_prefill:
if self.pcp_size > 1:
kv = torch.cat([key, value], dim=-1)
all_kv = get_pcp_group().all_gather(kv, dim=0)
pcp_allgather_restore_idx = attn_metadata.prefill.pcp_allgather_restore_idx if attn_metadata.prefill else None
all_kv = torch.index_select(all_kv, 0,
pcp_allgather_restore_idx)
key, value = all_kv.split(
[self.head_size, self.head_size], dim=-1)
torch_npu._npu_reshape_and_cache(
key=key[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
value=value[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=attn_metadata.
slot_mapping[self.pcp_size *
num_decode_tokens:attn_metadata.
num_actual_tokens_pcp_padded])
if self.pcp_size * self.dcp_size > 1:
output = self._forward_pcp_dcp(query, key, value,
attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
attn_out = torch_npu.npu_fusion_attention(
query,
@@ -668,7 +1059,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale=self.scale,
sparse_mode=4,
atten_mask=attn_metadata.attn_mask,
pre_tockens=attn_metadata.max_query_len,
next_tockens=attn_metadata.max_query_len,
actual_seq_qlen=cum_seq_len,
actual_seq_kvlen=cum_seq_len,
@@ -679,7 +1069,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
@@ -701,6 +1091,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: Optional[torch.Tensor]) -> torch.Tensor:
assert attn_metadata is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
if output is None:
raise ValueError("Output buffer is required")
if has_decode:
decode_query = query[:num_decode_tokens]
output_decode = self._forward_decode_pcp_dcp(
decode_query, attn_metadata)
output[:num_decode_tokens] = output_decode
if has_prefill:
prefill_query = query[num_decode_tokens:]
key = key[self.pcp_size * num_decode_tokens:]
value = value[self.pcp_size * num_decode_tokens:]
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
prefill_query, key, value, attn_metadata)
else:
max_prefill_seq_len = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:].max().item()
if attn_metadata.attn_mask is not None:
attn_metadata.attn_mask = attn_metadata.attn_mask[:
max_prefill_seq_len, :
max_prefill_seq_len]
else:
raise ValueError("attn_metadata.attn_mask is required")
seq_lens_back = attn_metadata.seq_lens
attn_metadata.seq_lens = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:]
output_prefill = self._forward_prefill_no_cache(
prefill_query, key, value, attn_metadata,
output[num_decode_tokens:], prefill_query.shape[0])
attn_metadata.seq_lens = seq_lens_back
output[num_decode_tokens:] = output_prefill
return output
def unified_ascend_attention_with_output(
query: torch.Tensor,