[feature] chunkprefill support pcp&dcp (#3801)

### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI tests passed with self-test


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
Apocalypse
2025-11-11 09:18:02 +08:00
committed by GitHub
parent 7ffbe73d54
commit 71866d5311
8 changed files with 1276 additions and 170 deletions

View File

@@ -37,6 +37,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
extract_req_dcp_by_chunk_pcp,
filter_chunked_req_indices,
split_decodes_and_prefills)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
@@ -52,6 +54,7 @@ if prefill_context_parallel_enable():
get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
@@ -155,9 +158,23 @@ class AscendPCPMetadata:
@dataclass
class AscendMetadataForPrefill:
@dataclass
class ChunkedContextMetadata:
actual_chunk_seq_lengths: list[int]
mask_for_non_zero_chunk: Optional[list[bool]] = None
max_chunk_num: int = 0
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
Optional[list[int]]]]]]]] = None
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
kv_inverse_idx_for_chunk: Optional[list[int]] = None
""" Prefill Specific Metadata for Ascend"""
pcp_metadata: Optional[AscendPCPMetadata] = None
pcp_allgather_restore_idx: Optional[List[int]] = None
chunked_context: Optional[ChunkedContextMetadata] = None
block_tables: torch.Tensor = None
actual_seq_lengths_q: torch.Tensor = None
@dataclass
@@ -165,6 +182,7 @@ class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
batch_seq_mask: torch.Tensor = None
block_tables: torch.Tensor = None
@dataclass
@@ -237,13 +255,10 @@ class AscendAttentionMetadataBuilder:
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0])
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
'decode_max_num_seqs', 0)
max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
decode_max_num_seqs)
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
dtype=torch.uint8,
device=device)
self.batch_seq_mask_buf = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.uint8,
device=device)
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
@@ -263,6 +278,27 @@ class AscendAttentionMetadataBuilder:
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False
@@ -279,9 +315,8 @@ class AscendAttentionMetadataBuilder:
num_reqs
+ 1]
decode_threshold = 1
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
@@ -302,6 +337,7 @@ class AscendAttentionMetadataBuilder:
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
num_computed_tokens_cpu = (seq_lens - query_lens)
if attn_state == AscendAttentionState.DecodeOnly and \
common_attn_metadata.num_input_tokens > num_actual_tokens:
@@ -338,16 +374,35 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
prefill_metadata = None
if num_prefills > 0:
pcp_metadata = None
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
decode_metadata = None
if common_long_seq_metadata is not None:
chunked_context_metadata = None
if num_prefills > 0:
query_lens = query_lens[num_decode_tokens:]
context_lens_cpu = num_computed_tokens_cpu[
num_decodes:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
kv_inverse_idx_for_chunk = torch.argsort(
cp_kv_recover_idx_for_chunk.to(torch.float32)
) if cp_kv_recover_idx_for_chunk is not None else None
chunked_context_metadata = \
AscendMetadataForPrefill.ChunkedContextMetadata(
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk,
local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens,
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
max_chunk_num=common_long_seq_metadata.max_chunk_num
)
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
if pcp_size > 1:
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
dim=0).tolist()
@@ -355,6 +410,7 @@ class AscendAttentionMetadataBuilder:
head_attn_nomask_seqlens[1], dim=0).tolist()
tail_attn_nomask_seqlens = torch.cumsum(
tail_attn_nomask_seqlens[1], dim=0).tolist()
pcp_metadata = AscendPCPMetadata(
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -371,16 +427,17 @@ class AscendAttentionMetadataBuilder:
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
q_full_idx=common_long_seq_metadata.q_full_idx,
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx
if common_long_seq_metadata is not None else None)
decode_metadata = None
if num_decodes > 0:
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
pcp_allgather_restore_idx=common_long_seq_metadata.
pcp_allgather_restore_idx
if common_long_seq_metadata is not None else None,
chunked_context=chunked_context_metadata,
block_tables=block_table[num_decodes:],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
if num_decodes > 0:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
assert num_computed_tokens_of_pcp_dcp is not None
num_computed_tokens_array = np.array(
@@ -397,7 +454,7 @@ class AscendAttentionMetadataBuilder:
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
shape[0]],
)
block_tables=block_table[:num_decodes])
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -751,7 +808,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
k_mask: torch.Tensor,
v_mask: torch.Tensor,
kv_seqlens_mask: List[int],
mask: torch.Tensor) -> torch.Tensor:
mask: torch.Tensor,
attn_metadata) -> torch.Tensor:
# nomask Attention
if k_nomask is not None:
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
@@ -786,30 +844,42 @@ class AscendAttentionBackendImpl(AttentionImpl):
softmax_lse_flag=True,
actual_seq_lengths_kv=kv_seqlens_mask,
actual_seq_lengths=q_seqlens)
# update
output = attn_out_mask
attn_lse = attn_lse_mask
if k_nomask is not None:
T = attn_out_mask.shape[0]
N = attn_out_mask.shape[1]
D = attn_out_mask.shape[2]
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None:
output = self._npu_attn_out_lse_update(attn_lse_mask,
attn_lse_nomask,
attn_out_mask,
attn_out_nomask)
attn_lse = None
else:
output, attn_lse = self._update_out_and_lse(
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
attn_out_mask, attn_lse_mask)
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
attn_out_nomask, attn_lse_nomask)
attn_out_mask = attn_out_mask.to(torch.float32)
attn_out_nomask = attn_out_nomask.to(torch.float32)
attn_lse_mask = attn_lse_mask.to(torch.float32)
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
attn_output = [attn_out_nomask, attn_out_mask]
attn_lse = [attn_lse_nomask, attn_lse_mask]
update_type = 0
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
update_type)
output = output.view(T, N, D)
return output, attn_lse
def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask,
attn_out_mask, attn_out_nomask):
T = attn_out_mask.shape[0]
N = attn_out_mask.shape[1]
D = attn_out_mask.shape[2]
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
attn_out_mask, attn_lse_mask)
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
attn_out_nomask, attn_lse_nomask)
attn_out_mask = attn_out_mask.to(torch.float32)
attn_out_nomask = attn_out_nomask.to(torch.float32)
attn_lse_mask = attn_lse_mask.to(torch.float32)
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
attn_output = [attn_out_nomask, attn_out_mask]
attn_lse = [attn_lse_nomask, attn_lse_mask]
update_type = 0
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
update_type)
output = output.view(T, N, D)
return output
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
@@ -831,7 +901,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
# 1. Attention calculation in the first half of Q in load balancing
output_head = self._attention_with_nomask_and_mask(
output_heads, lse_heads = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_head_idx),
q_seqlens=attn_mask_seqlens,
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
@@ -842,12 +912,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
kv_seqlens_mask=attn_mask_seqlens,
mask=mask)
mask=mask,
attn_metadata=attn_metadata)
# 2. the Attention calculation in the latter half of Q in load balancing
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
output_tail = self._attention_with_nomask_and_mask(
output_tails, lse_tails = self._attention_with_nomask_and_mask(
q=torch.index_select(query, 0, q_tail_idx),
q_seqlens=attn_mask_seqlens,
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
@@ -856,13 +927,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
kv_seqlens_mask=attn_mask_seqlens,
mask=mask)
mask=mask,
attn_metadata=attn_metadata)
# 3. Combine the output of the first half and second half.
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
output = torch.index_select(
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
return output
torch.cat([output_heads, output_tails], dim=0), 0, q_full_idx)
attn_lse = None
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
attn_lse = torch.index_select(
torch.cat([lse_heads, lse_tails], dim=0), 0, q_full_idx)
return output, attn_lse
def _out_lse_reshape(self, attn_out: torch.Tensor,
attn_lse: torch.Tensor) -> torch.Tensor:
@@ -928,7 +1003,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
'softmax_lse_flag':
True,
'block_table':
attn_metadata.block_tables,
attn_metadata.decode_meta.block_tables,
'block_size':
self.key_cache.shape[1],
'actual_seq_lengths_kv':
@@ -1029,8 +1104,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_out = self._npu_attention_update(attn_out_lse_list)
return attn_out
def _update_out_and_lse(self, out_list: torch.Tensor,
lse_list: torch.Tensor) -> torch.Tensor:
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
Args:
out_list: shape = [N, batch_size, num_heads, head_size]
lse_list: shape = [N, batch_size, num_heads, 1]
Returns:
out_final: shape = [batch_size, num_heads, head_size]
lse_final: shape = [batch_size, num_heads, 1]
"""
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
dim=0)
return out_final, lse_final
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
value: torch.Tensor, kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata,
output: torch.Tensor) -> torch.Tensor:
assert attn_metadata is not None
has_decode = attn_metadata.num_decodes > 0
@@ -1043,32 +1134,320 @@ class AscendAttentionBackendImpl(AttentionImpl):
decode_query, attn_metadata)
output[:num_decode_tokens] = output_decode
if has_prefill:
prefill_query = query[num_decode_tokens:]
assert attn_metadata.prefill is not None
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
prefill_query = query[
num_decode_tokens:num_actual_tokens_pcp_padded]
key = key[self.pcp_size * num_decode_tokens:]
value = value[self.pcp_size * num_decode_tokens:]
if self.pcp_size > 1:
output_prefill = self._forward_prefill_cp(
# Scenario of Enabling PCP or PCP&DCP
attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
prefill_query, key, value, attn_metadata)
else:
max_prefill_seq_len = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:].max().item()
if attn_metadata.attn_mask is not None:
attn_metadata.attn_mask = attn_metadata.attn_mask[:
max_prefill_seq_len, :
max_prefill_seq_len]
else:
ValueError("Attn_metadata.attn_mask is required")
seq_lens_back = attn_metadata.seq_lens
attn_metadata.seq_lens = attn_metadata.seq_lens[
attn_metadata.num_decode_tokens:]
output_prefill = self._forward_prefill_no_cache(
prefill_query, key, value, attn_metadata,
output[num_decode_tokens:], prefill_query.shape[0])
attn_metadata.seq_lens = seq_lens_back
output[num_decode_tokens:output_prefill.shape[0] +
num_decode_tokens] = output_prefill
# Scenario of Enabling DCP Individually
attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score(
prefill_query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
atten_mask=attn_metadata.attn_mask,
scale=self.scale,
sparse_mode=3,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=attn_metadata.prefill.
actual_seq_lengths_q,
actual_seq_lengths=attn_metadata.prefill.
actual_seq_lengths_q)
self._process_chunk_prefill(attn_output_prefill, attn_lse_prefill,
kv_cache, prefill_query, attn_metadata)
output[num_decode_tokens:attn_output_prefill.shape[0] +
num_decode_tokens] = attn_output_prefill
return output
def _process_chunk_prefill(self, current_attn_output_prefill,
current_attn_lse_prefill, kv_cache,
prefill_query, attn_metadata):
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
prefill_query_all = self._prefill_query_all_gather(
attn_metadata, prefill_query)
attn_output_full_chunk, attn_lse_full_chunk = self._compute_prefill_context(
prefill_query_all, kv_cache, attn_metadata)
self._update_chunk_attn_out_lse_with_current_attn_out_lse(
current_attn_output_prefill, current_attn_lse_prefill,
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
attn_metadata)
def _update_chunk_attn_out_lse_with_current_attn_out_lse(
self, current_attn_output_prefill, current_attn_lse_prefill,
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
attn_metadata):
if self.pcp_size > 1:
inverse_idx = attn_metadata.prefill.chunked_context.kv_inverse_idx_for_chunk
attn_output_full_chunk = torch.index_select(
attn_output_full_chunk, 0, inverse_idx)
attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0,
inverse_idx)
num_tokens = prefill_query.size(0)
attn_output_full_chunk = attn_output_full_chunk[
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
attn_lse_full_chunk = attn_lse_full_chunk[
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
seq_len = attn_metadata.query_lens.detach().clone()
filtered_indices = filter_chunked_req_indices(
seq_len,
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
attn_output_prefill_filtered = current_attn_output_prefill[
filtered_indices, :, :]
attn_lse_prefill_filtered = current_attn_lse_prefill[
filtered_indices, :, :]
attn_output_full_chunk = attn_output_full_chunk[filtered_indices, :, :]
attn_lse_full_chunk = attn_lse_full_chunk[filtered_indices, :, :]
attn_output_filtered = self._npu_attn_out_lse_update(
attn_lse_prefill_filtered, attn_lse_full_chunk,
attn_output_prefill_filtered, attn_output_full_chunk)
current_attn_output_prefill[
filtered_indices, :, :] = attn_output_filtered.to(
current_attn_output_prefill.dtype)
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(),
0) \
if self.pcp_size > 1 else prefill_query
prefill_query_all = torch.index_select(prefill_query_all,
0,
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
if self.pcp_size > 1 else prefill_query_all
return prefill_query_all
def _compute_prefill_context(self, query: torch.Tensor,
kv_cache: Tuple[torch.Tensor],
attn_metadata: AscendMetadata):
assert len(kv_cache) > 1
assert attn_metadata is not None
assert attn_metadata.prefill is not None
assert attn_metadata.prefill.chunked_context is not None
prefill_metadata = attn_metadata.prefill
local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
iters = max_chunk_num
# Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank]
context_starts_rank = None
prefix_output_list = []
prefix_lse_list = []
for i in range(iters):
key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk(
attn_metadata, kv_cache, context_starts_rank, i,
local_chunked_kv_lens, prefill_metadata, query)
# 2. Attention computation
if seq_lens_current_chunk_rank is None or torch.all(
seq_lens_current_chunk_rank == 0).item():
prefix_output = torch.full(
(query.size(0), self.num_heads, self.head_size),
fill_value=0,
dtype=query.dtype,
device=query.device)
prefix_lse = torch.full((query.size(0), self.num_heads, 1),
fill_value=0,
dtype=torch.float32,
device=query.device)
else:
actual_seq_lengths_kv = torch.cumsum(
seq_lens_current_chunk_rank, dim=0).tolist()
prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND", #
atten_mask=None,
scale=self.scale,
sparse_mode=0,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=attn_metadata.prefill.chunked_context.
actual_chunk_seq_lengths)
prefix_output_list.append(prefix_output)
prefix_lse_list.append(prefix_lse)
# 3. update attn-out & lse
prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks(
prefix_output_list, prefix_lse_list)
self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output,
prefix_lse)
return prefix_output, prefix_lse
def _update_attn_out_lse_in_chunks(self, prefix_output_list,
prefix_lse_list):
# update output and lse
if len(prefix_output_list) > 1:
prefix_output, prefix_lse = self._update_out_and_lse(
torch.stack(prefix_output_list, dim=0),
torch.stack(prefix_lse_list, dim=0))
else:
prefix_output = prefix_output_list[0]
prefix_lse = prefix_lse_list[0]
return prefix_output, prefix_lse
def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output,
prefix_lse):
# CP dimension all_gather and fusion
if self.pcp_size > 1:
# filter non-zero chunk part of prefix_output
current_seq_lens = attn_metadata.query_lens.detach().clone()
current_seq_lens.mul_(self.pcp_size) # q_full
current_seq_lens_cpu = current_seq_lens.cpu()
filtered_indices = filter_chunked_req_indices(
current_seq_lens_cpu,
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
prefix_output_filtered = prefix_output[filtered_indices, :, :]
prefix_lse_filtered = prefix_lse[filtered_indices, :, :]
out_lse_local = torch.cat(
[prefix_output_filtered, prefix_lse_filtered], dim=-1)
attn_out_lse_list = [
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
out_lse_local,
group=self.pcp_group)
attn_out_lse_allgather = torch.stack(
attn_out_lse_list,
dim=0) # [pcp, batch_size, num_heads, head_size+1]
attn_out_allgather, attn_lse_allgather = torch.split(
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse(
attn_out_allgather, attn_lse_allgather)
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
prefix_output.dtype)
prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to(
prefix_lse.dtype)
def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank,
i, local_chunked_kv_lens, prefill_metadata, query):
cache_key = kv_cache[0]
cache_value = kv_cache[1]
num_heads = cache_key.size(2)
head_size = kv_cache[0].size(-1)
# 1. Load current query's history key-value
seq_lens_current_chunk = attn_metadata.query_lens.detach().clone()
num_requests = len(seq_lens_current_chunk)
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
context_starts_rank = torch.zeros(
num_requests, dtype=torch.int32, device=query.device
) if context_starts_rank is None else context_starts_rank
# Calculate tokens each rank should process per request
seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk,
dtype=torch.int32,
device=query.device)
total_toks = 0
for req_idx in range(num_requests):
if i >= len(local_chunked_kv_lens[req_idx]):
continue
n_computed_acc = local_chunked_kv_lens[req_idx][i]
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
seq_lens_current_chunk_rank[req_idx] = n_computed_acc[
self.pcp_rank][self.dcp_rank]
if total_toks > 0:
key = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
value = torch.empty(total_toks,
num_heads,
head_size,
dtype=query.dtype,
device=query.device)
torch_npu.atb.npu_paged_cache_load(
cache_key,
cache_value,
attn_metadata.prefill.block_tables,
seq_lens_current_chunk_rank.to(query.device),
seq_starts=
context_starts_rank, # slot offsets of current chunk in current iteration
key=key,
value=value,
)
else:
# If current rank has no tokens to process, create empty tensors
key = torch.empty(0,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
value = torch.empty(0,
self.num_heads,
self.head_size,
dtype=query.dtype,
device=query.device)
seq_lens_current_chunk_rank = torch.zeros(
(len(seq_lens_current_chunk), ),
dtype=torch.int32,
device=query.device)
for req_idx in range(num_requests):
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
if i >= len(local_chunked_kv_lens[req_idx]):
continue
context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][
self.pcp_rank][self.dcp_rank]
if self.dcp_size > 1:
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank)
assert len(req_dcp_sizes) == num_requests and all(
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
total_toks = np.sum(np.array(req_dcp_sizes))
kv_local = torch.cat([key, value], dim=-1)
head_dim = kv_local.size(-1)
kv_full = torch.empty((total_toks, num_heads, head_dim),
device=query.device,
dtype=query.dtype)
kv_full_list = [None for _ in range(self.dcp_size)]
dist.all_gather_object(kv_full_list,
kv_local,
group=self.dcp_group)
kv_full_list = [
kv for kv in kv_full_list if kv is not None and kv.numel() > 0
]
if len(kv_full_list) > 0:
kv_full = torch.cat(kv_full_list, dim=0)
key, value = kv_full.split([head_size, head_size], dim=-1)
if total_toks == 0:
return key, value, None
seq_lens_current_chunk_rank = torch.tensor(
np.sum(np.array(req_dcp_sizes), axis=1),
dtype=torch.int32,
device=query.device) # [reqs]
return key, value, seq_lens_current_chunk_rank
def forward(
self,
layer: AttentionLayer,
@@ -1162,7 +1541,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
if self.pcp_size * self.dcp_size > 1:
intermediate_output = self._forward_pcp_dcp(
query, key, value, attn_metadata, output)
query, key, value, kv_cache, attn_metadata, output)
elif attn_type == AttentionType.ENCODER_ONLY:
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
@@ -1185,7 +1564,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
intermediate_output = self._forward_prefill_no_cache(
query, key, value, attn_metadata, output, num_tokens)
elif attn_metadata.attn_state == \
AscendAttentionState.PrefillCacheHit:
AscendAttentionState.PrefillCacheHit:
intermediate_output = self._forward_prefill_cache_hit(
query, attn_metadata, output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: