[feature] chunkprefill support pcp&dcp (#3801)

### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed with self-test


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
Apocalypse
2025-11-11 09:18:02 +08:00
committed by GitHub
parent 7ffbe73d54
commit 71866d5311
8 changed files with 1276 additions and 170 deletions

View File

@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
@@ -27,11 +28,14 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
# isort: off
from vllm_ascend.attention.utils import (
AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp,
filter_chunked_req_indices, maybe_save_kv_layer_to_connector,
split_decodes_and_prefills, trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
# isort: on
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
@@ -111,6 +115,10 @@ class AscendMLAPrefillMetadata:
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
chunk_seq_lens_npu: torch.Tensor
mask_for_non_zero_chunk: Optional[list[bool]] = None
max_chunk_num: int = 0
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
attn_mask: torch.Tensor
query_lens: torch.Tensor
@@ -125,6 +133,7 @@ class AscendMLAPrefillMetadata:
sin: torch.Tensor = None
cos: torch.Tensor = None
pcp_metadata: Optional[AscendPCPMetadata] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
@dataclass
@@ -347,6 +356,10 @@ class AscendMLAMetadataBuilder:
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None
local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None
mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None
max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
@@ -359,14 +372,15 @@ class AscendMLAMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
if num_actual_tokens_pcp_padded is None:
num_actual_tokens_pcp_padded = num_actual_tokens
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens_pcp_padded]
input_positions = common_attn_metadata.positions[:
num_actual_tokens_pcp_padded].long(
)
if self.cos_cache is None:
self.cos_cache = model.model.layers[
@@ -408,7 +422,8 @@ class AscendMLAMetadataBuilder:
tail_attn_nomask_seqlens=common_long_seq_metadata.
tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask
if long_seq_metadata else None,
pcp_allgather_restore_idx=long_seq_metadata.
pcp_allgather_restore_idx if long_seq_metadata else None)
@@ -452,6 +467,9 @@ class AscendMLAMetadataBuilder:
chunk_seq_lens=chunk_seq_lens,
chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
local_chunked_kv_lens=local_chunked_kv_lens,
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
max_chunk_num=max_chunk_num,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
@@ -474,7 +492,7 @@ class AscendMLAMetadataBuilder:
sin=sin,
cos=cos,
pcp_metadata=pcp_metadata,
)
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk)
decode_metadata = None
if num_decodes > 0:
@@ -887,8 +905,26 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
if self.pcp_size * self.dcp_size > 1:
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
if self.pcp_size > 1:
prefix_output = torch.zeros(q_nope.shape[0],
self.num_heads,
self.v_head_dim,
dtype=q_nope.dtype,
device=q_nope.device)
prefix_lse = torch.zeros(self.num_heads,
q_pe.shape[0],
dtype=torch.float32,
device=q_pe.device)
iters = len(prefill_metadata.chunked_context.seq_tot)
if self.pcp_size * self.dcp_size > 1:
iters = max_chunk_num
current_seq_len = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32)
@@ -896,60 +932,305 @@ class AscendMLAImpl(MLAAttentionImpl):
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
# token -> request mapping for building per-token masks when CP>1
seq_len1 = torch.tensor(prefill_metadata.query_lens,
dtype=torch.int32,
device=q_nope.device)
seq_len1.mul_(
self.pcp_size) # q_full: already padded, divisible by cp_size
# Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed.
mask_local = None
if attn_metadata is not None and attn_metadata.prefill is not None and \
attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None:
mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
else:
mask_local = self.prefill_mask
if mask_local is None:
mask_local = torch.triu(
torch.ones(512,
512,
device=q_nope.device,
dtype=q_nope.dtype), 1)
self.prefill_mask = mask_local
# Keep the causal mask; do not override to all-ones.
context_starts_rank = None
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
if self.pcp_size * self.dcp_size > 1:
## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension
num_requests = len(seq_len1)
assert num_requests == len(local_chunked_kv_lens)
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
context_starts_rank = torch.zeros(
num_requests, dtype=torch.int32, device=q_nope.device
) if context_starts_rank is None else context_starts_rank
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
## Calculate tokens each rank should process per request
seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32)
total_toks = 0
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)
for req_idx in range(num_requests):
if i >= len(local_chunked_kv_lens[req_idx]):
continue
n_computed_acc = local_chunked_kv_lens[req_idx][i]
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][
self.dcp_rank]
if total_toks > 0:
kv_c_normed = torch.empty(total_toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(total_toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2_rank.to(q_nope.device),
seq_starts=
context_starts_rank, # slot offsets of current chunk in current iteration
key=kv_c_normed,
value=k_pe,
)
seq_len2 = seq_len2_rank.to(q_nope.device)
else:
# If current rank has no tokens to process, create empty tensors
kv_c_normed = torch.empty(0,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(0,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
seq_len2 = torch.zeros((len(seq_len1), ),
dtype=torch.int32,
device=q_nope.device)
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
for req_idx in range(num_requests):
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
if i >= len(local_chunked_kv_lens[req_idx]):
continue
context_starts_rank[req_idx] += local_chunked_kv_lens[
req_idx][i][self.pcp_rank][self.dcp_rank]
else:
# Original logic: ChunkPrefill-only mode
toks = prefill_metadata.chunked_context.seq_tot[i]
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
i]
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
i]
seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
dtype=q_nope.dtype,
device=q_nope.device)
k_pe = torch.empty(toks,
num_heads,
rope_dim,
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)
kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=self.prefill_mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
if self.dcp_size > 1:
# DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks
# Step 1: DCP all_gather latent
kv_c_k_pe_local = torch.cat(
[kv_c_normed, k_pe.squeeze()],
dim=-1) # [local_toks, latent_dim + rope_dim]
# Step 2: use all_gather_into_tensor_uneven (gather + cat)
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank
) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp]
assert len(req_dcp_sizes) == num_requests and all(
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
total_toks = np.sum(np.array(req_dcp_sizes))
latent_rope_dim = kv_c_k_pe_local.size(-1)
kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim),
device=kv_c_k_pe_local.device,
dtype=kv_c_k_pe_local.dtype)
kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)]
dist.all_gather_object(kv_c_k_pe_full_list,
kv_c_k_pe_local,
group=self.dcp_group)
kv_c_k_pe_full_list = [
kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list
if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0
]
if len(kv_c_k_pe_full_list) > 0:
kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0)
if len(kv_c_k_pe_full.shape) == 1:
assert total_toks == 1
kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0)
assert kv_c_k_pe_full.shape[
0] == total_toks and kv_c_k_pe_full.shape[
1] == latent_rope_dim
kv_c_normed_full, k_pe_full = torch.split(
kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1)
# Step 3: process complete sequence with TP projection to get current rank's head slice
# Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation.
if total_toks == 0:
continue
kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1))
seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes),
axis=1),
dtype=torch.int32,
device=q_nope.device) # [reqs]
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
else:
# Non-DCP mode: use TP-split projection
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
if self.pcp_size > 1:
# Case that no kv_cache has been stored on this CP rank, no need to do following computation.
if torch.all(seq_len2 == 0).item():
continue
# PCP mode: first compute this rank's contribution to the chunk
if i == 0:
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=prefix_output,
softmax_lse=prefix_lse)
continue
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
else:
assert not torch.all(context_seq_len == 0).item()
# compute this chunk block then update prefix tensors to keep shapes consistent
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask_local,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
# CP dimension all_gather and fusion
if self.pcp_size > 1:
# filter non-zero chunk part of prefix_output
seq_len1_cpu = seq_len1.cpu()
filtered_indices = filter_chunked_req_indices(
seq_len1_cpu, mask_for_non_zero_chunk)
prefix_output_filtered = prefix_output[filtered_indices, :, :]
prefix_lse_filtered = prefix_lse[:, filtered_indices]
# normalize prefix LSE to [bs, heads, 1] for stable updates
prefix_lse_filtered_bt = prefix_lse_filtered.permute(
1, 0).unsqueeze(-1).contiguous(
) if prefix_lse_filtered is not None else None
out_lse_local = torch.cat(
[prefix_output_filtered, prefix_lse_filtered_bt], dim=-1)
out_lse_list = [
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
]
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group)
prefix_output_filtered = None
prefix_lse_filtered_bt = None
for r in range(self.pcp_size):
out_lse_r = out_lse_list[r]
if torch.all(out_lse_r == 0).item():
continue
out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1],
dim=-1)
token_mask = torch.ones([out_r.size(0)],
dtype=torch.uint8,
device=out_r.device)
prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse(
prefix_output_filtered, prefix_lse_filtered_bt, out_r,
lse_r, token_mask)
# convert lse back to [heads, bs]
assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None
prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute(
1, 0).contiguous()
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
prefix_output.dtype)
prefix_lse[:, filtered_indices] = prefix_lse_filtered.to(
prefix_lse.dtype)
return prefix_output, prefix_lse
def _forward_prefill(
@@ -1516,7 +1797,7 @@ class AscendMLAImpl(MLAAttentionImpl):
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
output_head = self._attention_with_mask_and_nomask(
output_head, head_lse = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_head_idx),
q_pe=torch.index_select(q_pe, 0, q_head_idx),
k_nope=k_nope,
@@ -1528,7 +1809,7 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_nomask_seqlens=head_attn_nomask_seqlens,
mask=mask)
output_tail = self._attention_with_mask_and_nomask(
output_tail, tail_lse = self._attention_with_mask_and_nomask(
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
k_nope=k_nope,
@@ -1544,7 +1825,83 @@ class AscendMLAImpl(MLAAttentionImpl):
output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
# Synchronize and reorder LSE for subsequent chunked context accumulation
attn_lse = torch.cat([head_lse, tail_lse], dim=1)
attn_lse = attn_lse[:, q_full_idx]
# Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed
if attn_metadata.prefill is not None and \
attn_metadata.prefill.chunked_context is not None:
# q all_gather
q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0)
q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0)
q_nope_full = torch.index_select(
q_nope_full, 0,
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
q_pe_full = torch.index_select(
q_pe_full, 0,
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
attn_output_pre = output.view(num_tokens, self.num_heads,
self.v_head_dim)
attn_output_pre_full, attn_lse_full = self._compute_prefill_context(
q_nope_full,
q_pe_full,
kv_c_and_k_pe_cache,
self.qk_rope_head_dim,
attn_metadata,
None,
None,
)
# reorder back && extract output + lse result of each cp rank
inverse_idx = torch.argsort(
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
attn_output_pre_full = torch.index_select(attn_output_pre_full, 0,
inverse_idx)
attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx)
attn_output_pre_new = attn_output_pre_full[
self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
num_tokens, :, :]
attn_lse_new = attn_lse_full[:, self.pcp_rank *
num_tokens:(self.pcp_rank + 1) *
num_tokens]
# update(output_origin, output_new)
assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape
seq_len = torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32)
mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk
filtered_indices = filter_chunked_req_indices(
seq_len, mask_for_non_zero_chunk)
attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :]
attn_lse_filtered = attn_lse[:, filtered_indices]
attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :]
attn_lse_new = attn_lse_new[:, filtered_indices]
# normalize prefix LSE to [bs, heads, 1] for stable updates
attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1)
attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1)
token_mask = torch.ones([attn_lse_new.size(0)],
dtype=torch.uint8,
device=attn_lse_new.device)
attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse(
attn_output_pre_filtered, attn_lse_filtered,
attn_output_pre_new, attn_lse_new, token_mask)
# convert lse back to [heads, bs]
attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute(
1, 0).contiguous()
attn_output_pre[
filtered_indices, :, :] = attn_output_pre_filtered.to(
attn_output_pre.dtype)
attn_lse[:,
filtered_indices] = attn_lse_filtered.to(attn_lse.dtype)
attn_output_pre = attn_output_pre.to(q_nope.dtype)
output = attn_output_pre.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
else:
output = output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
return output
@@ -1588,7 +1945,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# nomask
if kv_nomask_idx.shape[0] == 0:
return attn_output
return attn_output, attn_lse
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
@@ -1611,7 +1968,7 @@ class AscendMLAImpl(MLAAttentionImpl):
calc_type="calc_type_default",
output=attn_output,
softmax_lse=attn_lse)
return attn_output
return attn_output, attn_lse
def _forward_decode_pcp_dcp(
self,
@@ -1788,3 +2145,33 @@ class AscendMLAImpl(MLAAttentionImpl):
attn_out_lse_list = attn_out_lse_list_pcp_dcp
return attn_out_lse_list
# TODO use update op to replace this
def _update_out_and_lse(
self,
out: torch.Tensor,
lse: torch.Tensor,
block_out: torch.Tensor,
block_lse: torch.Tensor,
mask: torch.Tensor = None,
):
if out is None:
out = block_out.to(torch.float32)
lse = block_lse
else:
if mask is None:
mask = torch.ones([block_out.size(0)],
dtype=torch.uint8,
device=block_out.device)
out_mask = mask[:, None, None].expand_as(block_out)
lse_mask = mask[:, None, None].expand_as(block_lse)
block_out = block_out.to(torch.float32)
out_without_update = out.clone()
lse_without_update = lse.clone()
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
# mask
out = torch.where(out_mask, out, out_without_update)
lse = torch.where(lse_mask, lse, lse_without_update)
return out, lse