[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed with self-test
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -37,6 +37,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
extract_req_dcp_by_chunk_pcp,
|
||||
filter_chunked_req_indices,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
@@ -52,6 +54,7 @@ if prefill_context_parallel_enable():
|
||||
get_prefill_context_model_parallel_rank,
|
||||
get_prefill_context_model_parallel_world_size
|
||||
)
|
||||
|
||||
# isort: on
|
||||
|
||||
|
||||
@@ -155,9 +158,23 @@ class AscendPCPMetadata:
|
||||
|
||||
@dataclass
|
||||
class AscendMetadataForPrefill:
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
actual_chunk_seq_lengths: list[int]
|
||||
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||
max_chunk_num: int = 0
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||
Optional[list[int]]]]]]]] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
block_tables: torch.Tensor = None
|
||||
actual_seq_lengths_q: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -165,6 +182,7 @@ class AscendMetadataForDecode:
|
||||
""" Decode Specific Metadata for Ascend"""
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
batch_seq_mask: torch.Tensor = None
|
||||
block_tables: torch.Tensor = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -237,13 +255,10 @@ class AscendAttentionMetadataBuilder:
|
||||
self.max_num_blocks_per_req = cdiv(
|
||||
self.model_config.max_model_len,
|
||||
AscendAttentionBackend.get_supported_block_size()[0])
|
||||
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
|
||||
'decode_max_num_seqs', 0)
|
||||
max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
|
||||
decode_max_num_seqs)
|
||||
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.batch_seq_mask_buf = torch.empty(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||
@@ -263,6 +278,27 @@ class AscendAttentionMetadataBuilder:
|
||||
|
||||
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * self.model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
@@ -279,9 +315,8 @@ class AscendAttentionMetadataBuilder:
|
||||
num_reqs
|
||||
+ 1]
|
||||
|
||||
decode_threshold = 1
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
@@ -302,6 +337,7 @@ class AscendAttentionMetadataBuilder:
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
+ 1]
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
if attn_state == AscendAttentionState.DecodeOnly and \
|
||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||
@@ -338,16 +374,35 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
pcp_metadata = None
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
if common_long_seq_metadata is not None:
|
||||
decode_metadata = None
|
||||
if common_long_seq_metadata is not None:
|
||||
chunked_context_metadata = None
|
||||
if num_prefills > 0:
|
||||
query_lens = query_lens[num_decode_tokens:]
|
||||
context_lens_cpu = num_computed_tokens_cpu[
|
||||
num_decodes:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
|
||||
kv_inverse_idx_for_chunk = torch.argsort(
|
||||
cp_kv_recover_idx_for_chunk.to(torch.float32)
|
||||
) if cp_kv_recover_idx_for_chunk is not None else None
|
||||
chunked_context_metadata = \
|
||||
AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
||||
mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk,
|
||||
local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens,
|
||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||
max_chunk_num=common_long_seq_metadata.max_chunk_num
|
||||
)
|
||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
||||
pcp_size = get_prefill_context_model_parallel_world_size(
|
||||
) if prefill_context_parallel_enable() else 1
|
||||
if pcp_size > 1:
|
||||
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
|
||||
dim=0).tolist()
|
||||
@@ -355,6 +410,7 @@ class AscendAttentionMetadataBuilder:
|
||||
head_attn_nomask_seqlens[1], dim=0).tolist()
|
||||
tail_attn_nomask_seqlens = torch.cumsum(
|
||||
tail_attn_nomask_seqlens[1], dim=0).tolist()
|
||||
|
||||
pcp_metadata = AscendPCPMetadata(
|
||||
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
||||
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
||||
@@ -371,16 +427,17 @@ class AscendAttentionMetadataBuilder:
|
||||
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
|
||||
prefill_metadata = AscendMetadataForPrefill(
|
||||
pcp_metadata=pcp_metadata,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||
pcp_allgather_restore_idx
|
||||
if common_long_seq_metadata is not None else None)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||
if common_long_seq_metadata is not None:
|
||||
prefill_metadata = AscendMetadataForPrefill(
|
||||
pcp_metadata=pcp_metadata,
|
||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||
pcp_allgather_restore_idx
|
||||
if common_long_seq_metadata is not None else None,
|
||||
chunked_context=chunked_context_metadata,
|
||||
block_tables=block_table[num_decodes:],
|
||||
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
|
||||
|
||||
if num_decodes > 0:
|
||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||
assert num_computed_tokens_of_pcp_dcp is not None
|
||||
num_computed_tokens_array = np.array(
|
||||
@@ -397,7 +454,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
|
||||
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
|
||||
shape[0]],
|
||||
)
|
||||
block_tables=block_table[:num_decodes])
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -751,7 +808,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
k_mask: torch.Tensor,
|
||||
v_mask: torch.Tensor,
|
||||
kv_seqlens_mask: List[int],
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
mask: torch.Tensor,
|
||||
attn_metadata) -> torch.Tensor:
|
||||
# nomask Attention
|
||||
if k_nomask is not None:
|
||||
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
@@ -786,30 +844,42 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=kv_seqlens_mask,
|
||||
actual_seq_lengths=q_seqlens)
|
||||
|
||||
# update
|
||||
output = attn_out_mask
|
||||
attn_lse = attn_lse_mask
|
||||
if k_nomask is not None:
|
||||
T = attn_out_mask.shape[0]
|
||||
N = attn_out_mask.shape[1]
|
||||
D = attn_out_mask.shape[2]
|
||||
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None:
|
||||
output = self._npu_attn_out_lse_update(attn_lse_mask,
|
||||
attn_lse_nomask,
|
||||
attn_out_mask,
|
||||
attn_out_nomask)
|
||||
attn_lse = None
|
||||
else:
|
||||
output, attn_lse = self._update_out_and_lse(
|
||||
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
|
||||
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
|
||||
|
||||
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
|
||||
attn_out_mask, attn_lse_mask)
|
||||
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
|
||||
attn_out_nomask, attn_lse_nomask)
|
||||
attn_out_mask = attn_out_mask.to(torch.float32)
|
||||
attn_out_nomask = attn_out_nomask.to(torch.float32)
|
||||
attn_lse_mask = attn_lse_mask.to(torch.float32)
|
||||
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
|
||||
|
||||
attn_output = [attn_out_nomask, attn_out_mask]
|
||||
attn_lse = [attn_lse_nomask, attn_lse_mask]
|
||||
update_type = 0
|
||||
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
|
||||
update_type)
|
||||
output = output.view(T, N, D)
|
||||
return output, attn_lse
|
||||
|
||||
def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask,
|
||||
attn_out_mask, attn_out_nomask):
|
||||
T = attn_out_mask.shape[0]
|
||||
N = attn_out_mask.shape[1]
|
||||
D = attn_out_mask.shape[2]
|
||||
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
|
||||
attn_out_mask, attn_lse_mask)
|
||||
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
|
||||
attn_out_nomask, attn_lse_nomask)
|
||||
attn_out_mask = attn_out_mask.to(torch.float32)
|
||||
attn_out_nomask = attn_out_nomask.to(torch.float32)
|
||||
attn_lse_mask = attn_lse_mask.to(torch.float32)
|
||||
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
|
||||
attn_output = [attn_out_nomask, attn_out_mask]
|
||||
attn_lse = [attn_lse_nomask, attn_lse_mask]
|
||||
update_type = 0
|
||||
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
|
||||
update_type)
|
||||
output = output.view(T, N, D)
|
||||
return output
|
||||
|
||||
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
|
||||
@@ -831,7 +901,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
|
||||
# 1. Attention calculation in the first half of Q in load balancing
|
||||
output_head = self._attention_with_nomask_and_mask(
|
||||
output_heads, lse_heads = self._attention_with_nomask_and_mask(
|
||||
q=torch.index_select(query, 0, q_head_idx),
|
||||
q_seqlens=attn_mask_seqlens,
|
||||
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
|
||||
@@ -842,12 +912,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
|
||||
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
|
||||
kv_seqlens_mask=attn_mask_seqlens,
|
||||
mask=mask)
|
||||
mask=mask,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# 2. the Attention calculation in the latter half of Q in load balancing
|
||||
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
|
||||
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
|
||||
output_tail = self._attention_with_nomask_and_mask(
|
||||
output_tails, lse_tails = self._attention_with_nomask_and_mask(
|
||||
q=torch.index_select(query, 0, q_tail_idx),
|
||||
q_seqlens=attn_mask_seqlens,
|
||||
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
|
||||
@@ -856,13 +927,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
|
||||
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
|
||||
kv_seqlens_mask=attn_mask_seqlens,
|
||||
mask=mask)
|
||||
mask=mask,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
# 3. Combine the output of the first half and second half.
|
||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||
output = torch.index_select(
|
||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
||||
return output
|
||||
torch.cat([output_heads, output_tails], dim=0), 0, q_full_idx)
|
||||
attn_lse = None
|
||||
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
||||
attn_lse = torch.index_select(
|
||||
torch.cat([lse_heads, lse_tails], dim=0), 0, q_full_idx)
|
||||
return output, attn_lse
|
||||
|
||||
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||
attn_lse: torch.Tensor) -> torch.Tensor:
|
||||
@@ -928,7 +1003,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
'softmax_lse_flag':
|
||||
True,
|
||||
'block_table':
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.decode_meta.block_tables,
|
||||
'block_size':
|
||||
self.key_cache.shape[1],
|
||||
'actual_seq_lengths_kv':
|
||||
@@ -1029,8 +1104,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_out = self._npu_attention_update(attn_out_lse_list)
|
||||
return attn_out
|
||||
|
||||
def _update_out_and_lse(self, out_list: torch.Tensor,
|
||||
lse_list: torch.Tensor) -> torch.Tensor:
|
||||
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
|
||||
Args:
|
||||
out_list: shape = [N, batch_size, num_heads, head_size]
|
||||
lse_list: shape = [N, batch_size, num_heads, 1]
|
||||
Returns:
|
||||
out_final: shape = [batch_size, num_heads, head_size]
|
||||
lse_final: shape = [batch_size, num_heads, 1]
|
||||
"""
|
||||
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
|
||||
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
|
||||
dim=0)
|
||||
return out_final, lse_final
|
||||
|
||||
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||
value: torch.Tensor, kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor) -> torch.Tensor:
|
||||
assert attn_metadata is not None
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
@@ -1043,32 +1134,320 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
decode_query, attn_metadata)
|
||||
output[:num_decode_tokens] = output_decode
|
||||
if has_prefill:
|
||||
prefill_query = query[num_decode_tokens:]
|
||||
assert attn_metadata.prefill is not None
|
||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||
prefill_query = query[
|
||||
num_decode_tokens:num_actual_tokens_pcp_padded]
|
||||
key = key[self.pcp_size * num_decode_tokens:]
|
||||
value = value[self.pcp_size * num_decode_tokens:]
|
||||
if self.pcp_size > 1:
|
||||
output_prefill = self._forward_prefill_cp(
|
||||
# Scenario of Enabling PCP or PCP&DCP
|
||||
attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
|
||||
prefill_query, key, value, attn_metadata)
|
||||
else:
|
||||
max_prefill_seq_len = attn_metadata.seq_lens[
|
||||
attn_metadata.num_decode_tokens:].max().item()
|
||||
if attn_metadata.attn_mask is not None:
|
||||
attn_metadata.attn_mask = attn_metadata.attn_mask[:
|
||||
max_prefill_seq_len, :
|
||||
max_prefill_seq_len]
|
||||
else:
|
||||
ValueError("Attn_metadata.attn_mask is required")
|
||||
seq_lens_back = attn_metadata.seq_lens
|
||||
attn_metadata.seq_lens = attn_metadata.seq_lens[
|
||||
attn_metadata.num_decode_tokens:]
|
||||
output_prefill = self._forward_prefill_no_cache(
|
||||
prefill_query, key, value, attn_metadata,
|
||||
output[num_decode_tokens:], prefill_query.shape[0])
|
||||
attn_metadata.seq_lens = seq_lens_back
|
||||
output[num_decode_tokens:output_prefill.shape[0] +
|
||||
num_decode_tokens] = output_prefill
|
||||
# Scenario of Enabling DCP Individually
|
||||
attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
prefill_query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=attn_metadata.prefill.
|
||||
actual_seq_lengths_q,
|
||||
actual_seq_lengths=attn_metadata.prefill.
|
||||
actual_seq_lengths_q)
|
||||
|
||||
self._process_chunk_prefill(attn_output_prefill, attn_lse_prefill,
|
||||
kv_cache, prefill_query, attn_metadata)
|
||||
output[num_decode_tokens:attn_output_prefill.shape[0] +
|
||||
num_decode_tokens] = attn_output_prefill
|
||||
return output
|
||||
|
||||
def _process_chunk_prefill(self, current_attn_output_prefill,
|
||||
current_attn_lse_prefill, kv_cache,
|
||||
prefill_query, attn_metadata):
|
||||
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
||||
prefill_query_all = self._prefill_query_all_gather(
|
||||
attn_metadata, prefill_query)
|
||||
attn_output_full_chunk, attn_lse_full_chunk = self._compute_prefill_context(
|
||||
prefill_query_all, kv_cache, attn_metadata)
|
||||
self._update_chunk_attn_out_lse_with_current_attn_out_lse(
|
||||
current_attn_output_prefill, current_attn_lse_prefill,
|
||||
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
|
||||
attn_metadata)
|
||||
|
||||
def _update_chunk_attn_out_lse_with_current_attn_out_lse(
|
||||
self, current_attn_output_prefill, current_attn_lse_prefill,
|
||||
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
|
||||
attn_metadata):
|
||||
if self.pcp_size > 1:
|
||||
inverse_idx = attn_metadata.prefill.chunked_context.kv_inverse_idx_for_chunk
|
||||
attn_output_full_chunk = torch.index_select(
|
||||
attn_output_full_chunk, 0, inverse_idx)
|
||||
attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0,
|
||||
inverse_idx)
|
||||
num_tokens = prefill_query.size(0)
|
||||
attn_output_full_chunk = attn_output_full_chunk[
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||
attn_lse_full_chunk = attn_lse_full_chunk[
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
|
||||
seq_len = attn_metadata.query_lens.detach().clone()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len,
|
||||
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||
|
||||
attn_output_prefill_filtered = current_attn_output_prefill[
|
||||
filtered_indices, :, :]
|
||||
attn_lse_prefill_filtered = current_attn_lse_prefill[
|
||||
filtered_indices, :, :]
|
||||
attn_output_full_chunk = attn_output_full_chunk[filtered_indices, :, :]
|
||||
attn_lse_full_chunk = attn_lse_full_chunk[filtered_indices, :, :]
|
||||
|
||||
attn_output_filtered = self._npu_attn_out_lse_update(
|
||||
attn_lse_prefill_filtered, attn_lse_full_chunk,
|
||||
attn_output_prefill_filtered, attn_output_full_chunk)
|
||||
current_attn_output_prefill[
|
||||
filtered_indices, :, :] = attn_output_filtered.to(
|
||||
current_attn_output_prefill.dtype)
|
||||
|
||||
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
|
||||
prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(),
|
||||
0) \
|
||||
if self.pcp_size > 1 else prefill_query
|
||||
prefill_query_all = torch.index_select(prefill_query_all,
|
||||
0,
|
||||
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
|
||||
if self.pcp_size > 1 else prefill_query_all
|
||||
return prefill_query_all
|
||||
|
||||
def _compute_prefill_context(self, query: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata):
|
||||
assert len(kv_cache) > 1
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.prefill is not None
|
||||
assert attn_metadata.prefill.chunked_context is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens
|
||||
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||
|
||||
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||
|
||||
iters = max_chunk_num
|
||||
# Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank]
|
||||
context_starts_rank = None
|
||||
|
||||
prefix_output_list = []
|
||||
prefix_lse_list = []
|
||||
for i in range(iters):
|
||||
key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk(
|
||||
attn_metadata, kv_cache, context_starts_rank, i,
|
||||
local_chunked_kv_lens, prefill_metadata, query)
|
||||
|
||||
# 2. Attention computation
|
||||
if seq_lens_current_chunk_rank is None or torch.all(
|
||||
seq_lens_current_chunk_rank == 0).item():
|
||||
prefix_output = torch.full(
|
||||
(query.size(0), self.num_heads, self.head_size),
|
||||
fill_value=0,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
prefix_lse = torch.full((query.size(0), self.num_heads, 1),
|
||||
fill_value=0,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
else:
|
||||
|
||||
actual_seq_lengths_kv = torch.cumsum(
|
||||
seq_lens_current_chunk_rank, dim=0).tolist()
|
||||
prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="TND", #
|
||||
atten_mask=None,
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_lse_flag=True,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
||||
actual_chunk_seq_lengths)
|
||||
prefix_output_list.append(prefix_output)
|
||||
prefix_lse_list.append(prefix_lse)
|
||||
|
||||
# 3. update attn-out & lse
|
||||
prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks(
|
||||
prefix_output_list, prefix_lse_list)
|
||||
|
||||
self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output,
|
||||
prefix_lse)
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _update_attn_out_lse_in_chunks(self, prefix_output_list,
|
||||
prefix_lse_list):
|
||||
# update output and lse
|
||||
if len(prefix_output_list) > 1:
|
||||
prefix_output, prefix_lse = self._update_out_and_lse(
|
||||
torch.stack(prefix_output_list, dim=0),
|
||||
torch.stack(prefix_lse_list, dim=0))
|
||||
else:
|
||||
prefix_output = prefix_output_list[0]
|
||||
prefix_lse = prefix_lse_list[0]
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output,
|
||||
prefix_lse):
|
||||
# CP dimension all_gather and fusion
|
||||
if self.pcp_size > 1:
|
||||
# filter non-zero chunk part of prefix_output
|
||||
current_seq_lens = attn_metadata.query_lens.detach().clone()
|
||||
current_seq_lens.mul_(self.pcp_size) # q_full
|
||||
current_seq_lens_cpu = current_seq_lens.cpu()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
current_seq_lens_cpu,
|
||||
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||
prefix_lse_filtered = prefix_lse[filtered_indices, :, :]
|
||||
|
||||
out_lse_local = torch.cat(
|
||||
[prefix_output_filtered, prefix_lse_filtered], dim=-1)
|
||||
attn_out_lse_list = [
|
||||
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(attn_out_lse_list,
|
||||
out_lse_local,
|
||||
group=self.pcp_group)
|
||||
|
||||
attn_out_lse_allgather = torch.stack(
|
||||
attn_out_lse_list,
|
||||
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
||||
attn_out_allgather, attn_lse_allgather = torch.split(
|
||||
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
||||
prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse(
|
||||
attn_out_allgather, attn_lse_allgather)
|
||||
|
||||
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||
prefix_output.dtype)
|
||||
prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to(
|
||||
prefix_lse.dtype)
|
||||
|
||||
def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank,
|
||||
i, local_chunked_kv_lens, prefill_metadata, query):
|
||||
|
||||
cache_key = kv_cache[0]
|
||||
cache_value = kv_cache[1]
|
||||
num_heads = cache_key.size(2)
|
||||
head_size = kv_cache[0].size(-1)
|
||||
|
||||
# 1. Load current query's history key-value
|
||||
seq_lens_current_chunk = attn_metadata.query_lens.detach().clone()
|
||||
num_requests = len(seq_lens_current_chunk)
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
context_starts_rank = torch.zeros(
|
||||
num_requests, dtype=torch.int32, device=query.device
|
||||
) if context_starts_rank is None else context_starts_rank
|
||||
# Calculate tokens each rank should process per request
|
||||
seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk,
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
total_toks = 0
|
||||
for req_idx in range(num_requests):
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||
seq_lens_current_chunk_rank[req_idx] = n_computed_acc[
|
||||
self.pcp_rank][self.dcp_rank]
|
||||
if total_toks > 0:
|
||||
key = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
value = torch.empty(total_toks,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_key,
|
||||
cache_value,
|
||||
attn_metadata.prefill.block_tables,
|
||||
seq_lens_current_chunk_rank.to(query.device),
|
||||
seq_starts=
|
||||
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
else:
|
||||
# If current rank has no tokens to process, create empty tensors
|
||||
key = torch.empty(0,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
value = torch.empty(0,
|
||||
self.num_heads,
|
||||
self.head_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
seq_lens_current_chunk_rank = torch.zeros(
|
||||
(len(seq_lens_current_chunk), ),
|
||||
dtype=torch.int32,
|
||||
device=query.device)
|
||||
for req_idx in range(num_requests):
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][
|
||||
self.pcp_rank][self.dcp_rank]
|
||||
if self.dcp_size > 1:
|
||||
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank)
|
||||
|
||||
assert len(req_dcp_sizes) == num_requests and all(
|
||||
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||
kv_local = torch.cat([key, value], dim=-1)
|
||||
head_dim = kv_local.size(-1)
|
||||
kv_full = torch.empty((total_toks, num_heads, head_dim),
|
||||
device=query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
kv_full_list = [None for _ in range(self.dcp_size)]
|
||||
dist.all_gather_object(kv_full_list,
|
||||
kv_local,
|
||||
group=self.dcp_group)
|
||||
kv_full_list = [
|
||||
kv for kv in kv_full_list if kv is not None and kv.numel() > 0
|
||||
]
|
||||
|
||||
if len(kv_full_list) > 0:
|
||||
kv_full = torch.cat(kv_full_list, dim=0)
|
||||
key, value = kv_full.split([head_size, head_size], dim=-1)
|
||||
if total_toks == 0:
|
||||
return key, value, None
|
||||
seq_lens_current_chunk_rank = torch.tensor(
|
||||
np.sum(np.array(req_dcp_sizes), axis=1),
|
||||
dtype=torch.int32,
|
||||
device=query.device) # [reqs]
|
||||
return key, value, seq_lens_current_chunk_rank
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
@@ -1162,7 +1541,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
intermediate_output = self._forward_pcp_dcp(
|
||||
query, key, value, attn_metadata, output)
|
||||
query, key, value, kv_cache, attn_metadata, output)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
@@ -1185,7 +1564,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
intermediate_output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
intermediate_output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
@@ -27,11 +28,14 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills,
|
||||
trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.attention.utils import (
|
||||
AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp,
|
||||
filter_chunked_req_indices, maybe_save_kv_layer_to_connector,
|
||||
split_decodes_and_prefills, trans_rope_weight, transdata,
|
||||
wait_for_kv_layer_from_connector)
|
||||
# isort: on
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
@@ -111,6 +115,10 @@ class AscendMLAPrefillMetadata:
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
chunk_seq_lens_npu: torch.Tensor
|
||||
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||
max_chunk_num: int = 0
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||
Optional[list[int]]]]]]]] = None
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: torch.Tensor
|
||||
@@ -125,6 +133,7 @@ class AscendMLAPrefillMetadata:
|
||||
sin: torch.Tensor = None
|
||||
cos: torch.Tensor = None
|
||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -347,6 +356,10 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
||||
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
|
||||
cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None
|
||||
local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None
|
||||
mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None
|
||||
max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||
@@ -359,14 +372,15 @@ class AscendMLAMetadataBuilder:
|
||||
device = self.device
|
||||
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
if num_actual_tokens_pcp_padded is None:
|
||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||
|
||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||
num_actual_tokens_pcp_padded]
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens_pcp_padded].long(
|
||||
)
|
||||
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = model.model.layers[
|
||||
@@ -408,7 +422,8 @@ class AscendMLAMetadataBuilder:
|
||||
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||
tail_attn_nomask_seqlens,
|
||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask
|
||||
if long_seq_metadata else None,
|
||||
pcp_allgather_restore_idx=long_seq_metadata.
|
||||
pcp_allgather_restore_idx if long_seq_metadata else None)
|
||||
|
||||
@@ -452,6 +467,9 @@ class AscendMLAMetadataBuilder:
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
local_chunked_kv_lens=local_chunked_kv_lens,
|
||||
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
|
||||
max_chunk_num=max_chunk_num,
|
||||
)
|
||||
prefill_input_positions = input_positions[tokens_start:]
|
||||
cos = self.cos_cache[
|
||||
@@ -474,7 +492,7 @@ class AscendMLAMetadataBuilder:
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
pcp_metadata=pcp_metadata,
|
||||
)
|
||||
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk)
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
@@ -887,8 +905,26 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens
|
||||
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||
|
||||
if self.pcp_size > 1:
|
||||
prefix_output = torch.zeros(q_nope.shape[0],
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
prefix_lse = torch.zeros(self.num_heads,
|
||||
q_pe.shape[0],
|
||||
dtype=torch.float32,
|
||||
device=q_pe.device)
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
iters = max_chunk_num
|
||||
|
||||
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||
dtype=torch.int32)
|
||||
@@ -896,60 +932,305 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||
num_heads = cache_k_pe.size(2)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||
# token -> request mapping for building per-token masks when CP>1
|
||||
seq_len1 = torch.tensor(prefill_metadata.query_lens,
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device)
|
||||
seq_len1.mul_(
|
||||
self.pcp_size) # q_full: already padded, divisible by cp_size
|
||||
|
||||
# Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed.
|
||||
mask_local = None
|
||||
if attn_metadata is not None and attn_metadata.prefill is not None and \
|
||||
attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None:
|
||||
mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
else:
|
||||
mask_local = self.prefill_mask
|
||||
if mask_local is None:
|
||||
mask_local = torch.triu(
|
||||
torch.ones(512,
|
||||
512,
|
||||
device=q_nope.device,
|
||||
dtype=q_nope.dtype), 1)
|
||||
self.prefill_mask = mask_local
|
||||
|
||||
# Keep the causal mask; do not override to all-ones.
|
||||
context_starts_rank = None
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension
|
||||
num_requests = len(seq_len1)
|
||||
assert num_requests == len(local_chunked_kv_lens)
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
context_starts_rank = torch.zeros(
|
||||
num_requests, dtype=torch.int32, device=q_nope.device
|
||||
) if context_starts_rank is None else context_starts_rank
|
||||
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
## Calculate tokens each rank should process per request
|
||||
seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32)
|
||||
total_toks = 0
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
for req_idx in range(num_requests):
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||
seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][
|
||||
self.dcp_rank]
|
||||
|
||||
if total_toks > 0:
|
||||
kv_c_normed = torch.empty(total_toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(total_toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
seq_len2_rank.to(q_nope.device),
|
||||
seq_starts=
|
||||
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
seq_len2 = seq_len2_rank.to(q_nope.device)
|
||||
else:
|
||||
# If current rank has no tokens to process, create empty tensors
|
||||
kv_c_normed = torch.empty(0,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(0,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
seq_len2 = torch.zeros((len(seq_len1), ),
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device)
|
||||
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||
for req_idx in range(num_requests):
|
||||
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||
continue
|
||||
context_starts_rank[req_idx] += local_chunked_kv_lens[
|
||||
req_idx][i][self.pcp_rank][self.dcp_rank]
|
||||
else:
|
||||
# Original logic: ChunkPrefill-only mode
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||
i]
|
||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||
i]
|
||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||
|
||||
kv_c_normed = torch.empty(toks,
|
||||
num_heads,
|
||||
latent_kv_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
k_pe = torch.empty(toks,
|
||||
num_heads,
|
||||
rope_dim,
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
context_seq_len_npu,
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=self.prefill_mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
if self.dcp_size > 1:
|
||||
# DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks
|
||||
# Step 1: DCP all_gather latent
|
||||
kv_c_k_pe_local = torch.cat(
|
||||
[kv_c_normed, k_pe.squeeze()],
|
||||
dim=-1) # [local_toks, latent_dim + rope_dim]
|
||||
|
||||
# Step 2: use all_gather_into_tensor_uneven (gather + cat)
|
||||
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank
|
||||
) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp]
|
||||
assert len(req_dcp_sizes) == num_requests and all(
|
||||
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||
latent_rope_dim = kv_c_k_pe_local.size(-1)
|
||||
kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim),
|
||||
device=kv_c_k_pe_local.device,
|
||||
dtype=kv_c_k_pe_local.dtype)
|
||||
|
||||
kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)]
|
||||
dist.all_gather_object(kv_c_k_pe_full_list,
|
||||
kv_c_k_pe_local,
|
||||
group=self.dcp_group)
|
||||
kv_c_k_pe_full_list = [
|
||||
kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list
|
||||
if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0
|
||||
]
|
||||
if len(kv_c_k_pe_full_list) > 0:
|
||||
kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0)
|
||||
if len(kv_c_k_pe_full.shape) == 1:
|
||||
assert total_toks == 1
|
||||
kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0)
|
||||
assert kv_c_k_pe_full.shape[
|
||||
0] == total_toks and kv_c_k_pe_full.shape[
|
||||
1] == latent_rope_dim
|
||||
kv_c_normed_full, k_pe_full = torch.split(
|
||||
kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1)
|
||||
|
||||
# Step 3: process complete sequence with TP projection to get current rank's head slice
|
||||
# Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation.
|
||||
if total_toks == 0:
|
||||
continue
|
||||
kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes),
|
||||
axis=1),
|
||||
dtype=torch.int32,
|
||||
device=q_nope.device) # [reqs]
|
||||
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||
else:
|
||||
# Non-DCP mode: use TP-split projection
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
|
||||
if self.pcp_size > 1:
|
||||
# Case that no kv_cache has been stored on this CP rank, no need to do following computation.
|
||||
if torch.all(seq_len2 == 0).item():
|
||||
continue
|
||||
# PCP mode: first compute this rank's contribution to the chunk
|
||||
if i == 0:
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
continue
|
||||
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
|
||||
else:
|
||||
assert not torch.all(context_seq_len == 0).item()
|
||||
# compute this chunk block then update prefix tensors to keep shapes consistent
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask_local,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
|
||||
# CP dimension all_gather and fusion
|
||||
if self.pcp_size > 1:
|
||||
# filter non-zero chunk part of prefix_output
|
||||
seq_len1_cpu = seq_len1.cpu()
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len1_cpu, mask_for_non_zero_chunk)
|
||||
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||
prefix_lse_filtered = prefix_lse[:, filtered_indices]
|
||||
|
||||
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||
prefix_lse_filtered_bt = prefix_lse_filtered.permute(
|
||||
1, 0).unsqueeze(-1).contiguous(
|
||||
) if prefix_lse_filtered is not None else None
|
||||
out_lse_local = torch.cat(
|
||||
[prefix_output_filtered, prefix_lse_filtered_bt], dim=-1)
|
||||
out_lse_list = [
|
||||
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group)
|
||||
prefix_output_filtered = None
|
||||
prefix_lse_filtered_bt = None
|
||||
for r in range(self.pcp_size):
|
||||
out_lse_r = out_lse_list[r]
|
||||
if torch.all(out_lse_r == 0).item():
|
||||
continue
|
||||
out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1],
|
||||
dim=-1)
|
||||
token_mask = torch.ones([out_r.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=out_r.device)
|
||||
prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse(
|
||||
prefix_output_filtered, prefix_lse_filtered_bt, out_r,
|
||||
lse_r, token_mask)
|
||||
# convert lse back to [heads, bs]
|
||||
assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None
|
||||
prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute(
|
||||
1, 0).contiguous()
|
||||
|
||||
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||
prefix_output.dtype)
|
||||
prefix_lse[:, filtered_indices] = prefix_lse_filtered.to(
|
||||
prefix_lse.dtype)
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _forward_prefill(
|
||||
@@ -1516,7 +1797,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||
|
||||
output_head = self._attention_with_mask_and_nomask(
|
||||
output_head, head_lse = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||
k_nope=k_nope,
|
||||
@@ -1528,7 +1809,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||
mask=mask)
|
||||
|
||||
output_tail = self._attention_with_mask_and_nomask(
|
||||
output_tail, tail_lse = self._attention_with_mask_and_nomask(
|
||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
||||
k_nope=k_nope,
|
||||
@@ -1544,7 +1825,83 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
output = torch.index_select(
|
||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
||||
|
||||
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
|
||||
# Synchronize and reorder LSE for subsequent chunked context accumulation
|
||||
attn_lse = torch.cat([head_lse, tail_lse], dim=1)
|
||||
attn_lse = attn_lse[:, q_full_idx]
|
||||
|
||||
# Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed
|
||||
if attn_metadata.prefill is not None and \
|
||||
attn_metadata.prefill.chunked_context is not None:
|
||||
# q all_gather
|
||||
q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0)
|
||||
q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0)
|
||||
q_nope_full = torch.index_select(
|
||||
q_nope_full, 0,
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
q_pe_full = torch.index_select(
|
||||
q_pe_full, 0,
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
attn_output_pre = output.view(num_tokens, self.num_heads,
|
||||
self.v_head_dim)
|
||||
attn_output_pre_full, attn_lse_full = self._compute_prefill_context(
|
||||
q_nope_full,
|
||||
q_pe_full,
|
||||
kv_c_and_k_pe_cache,
|
||||
self.qk_rope_head_dim,
|
||||
attn_metadata,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
# reorder back && extract output + lse result of each cp rank
|
||||
inverse_idx = torch.argsort(
|
||||
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||
attn_output_pre_full = torch.index_select(attn_output_pre_full, 0,
|
||||
inverse_idx)
|
||||
attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx)
|
||||
attn_output_pre_new = attn_output_pre_full[
|
||||
self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
|
||||
num_tokens, :, :]
|
||||
attn_lse_new = attn_lse_full[:, self.pcp_rank *
|
||||
num_tokens:(self.pcp_rank + 1) *
|
||||
num_tokens]
|
||||
|
||||
# update(output_origin, output_new)
|
||||
assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape
|
||||
seq_len = torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32)
|
||||
mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk
|
||||
filtered_indices = filter_chunked_req_indices(
|
||||
seq_len, mask_for_non_zero_chunk)
|
||||
attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :]
|
||||
attn_lse_filtered = attn_lse[:, filtered_indices]
|
||||
attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :]
|
||||
attn_lse_new = attn_lse_new[:, filtered_indices]
|
||||
|
||||
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||
attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1)
|
||||
attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1)
|
||||
token_mask = torch.ones([attn_lse_new.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=attn_lse_new.device)
|
||||
attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse(
|
||||
attn_output_pre_filtered, attn_lse_filtered,
|
||||
attn_output_pre_new, attn_lse_new, token_mask)
|
||||
# convert lse back to [heads, bs]
|
||||
attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute(
|
||||
1, 0).contiguous()
|
||||
|
||||
attn_output_pre[
|
||||
filtered_indices, :, :] = attn_output_pre_filtered.to(
|
||||
attn_output_pre.dtype)
|
||||
attn_lse[:,
|
||||
filtered_indices] = attn_lse_filtered.to(attn_lse.dtype)
|
||||
|
||||
attn_output_pre = attn_output_pre.to(q_nope.dtype)
|
||||
output = attn_output_pre.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
else:
|
||||
output = output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
|
||||
return output
|
||||
|
||||
@@ -1588,7 +1945,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
# nomask
|
||||
if kv_nomask_idx.shape[0] == 0:
|
||||
return attn_output
|
||||
return attn_output, attn_lse
|
||||
|
||||
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
|
||||
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
|
||||
@@ -1611,7 +1968,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
calc_type="calc_type_default",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
return attn_output
|
||||
return attn_output, attn_lse
|
||||
|
||||
def _forward_decode_pcp_dcp(
|
||||
self,
|
||||
@@ -1788,3 +2145,33 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||
|
||||
return attn_out_lse_list
|
||||
|
||||
# TODO use update op to replace this
|
||||
def _update_out_and_lse(
|
||||
self,
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
):
|
||||
if out is None:
|
||||
out = block_out.to(torch.float32)
|
||||
lse = block_lse
|
||||
else:
|
||||
if mask is None:
|
||||
mask = torch.ones([block_out.size(0)],
|
||||
dtype=torch.uint8,
|
||||
device=block_out.device)
|
||||
out_mask = mask[:, None, None].expand_as(block_out)
|
||||
lse_mask = mask[:, None, None].expand_as(block_lse)
|
||||
block_out = block_out.to(torch.float32)
|
||||
out_without_update = out.clone()
|
||||
lse_without_update = lse.clone()
|
||||
|
||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||
lse = lse - F.logsigmoid(lse - block_lse)
|
||||
# mask
|
||||
out = torch.where(out_mask, out, out_without_update)
|
||||
lse = torch.where(lse_mask, lse, lse_without_update)
|
||||
return out, lse
|
||||
|
||||
@@ -14,10 +14,19 @@ from vllm.forward_context import ForwardContext, get_forward_context
|
||||
class AscendPrefillContextParallelMetadata:
|
||||
pcp_allgather_restore_idx: torch.Tensor = None
|
||||
|
||||
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
||||
|
||||
num_actual_tokens_pcp_padded: Optional[int] = None
|
||||
|
||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||
|
||||
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
|
||||
list[int]]]]]]]] = None
|
||||
|
||||
mask_for_non_zero_chunk: Optional[List[bool]] = None
|
||||
|
||||
max_chunk_num: int = 0
|
||||
|
||||
q_head_idx_tensor: torch.Tensor = None
|
||||
|
||||
q_tail_idx_tensor: torch.Tensor = None
|
||||
@@ -46,7 +55,7 @@ class AscendCommonAttentionMetadata:
|
||||
"""
|
||||
Per-batch attention metadata, shared across layers and backends.
|
||||
AttentionMetadataBuilder instances use it to construct per-layer metadata.
|
||||
|
||||
|
||||
For many of the tensors we keep both GPU and CPU versions.
|
||||
"""
|
||||
|
||||
@@ -106,6 +115,47 @@ class AscendCommonAttentionMetadata:
|
||||
AscendPrefillContextParallelMetadata] = None
|
||||
|
||||
|
||||
def extract_req_dcp_by_chunk_pcp(lst,
|
||||
chunk_idx,
|
||||
dcp_size,
|
||||
pcp_rank,
|
||||
fill_value=0):
|
||||
num_reqs = len(lst)
|
||||
results: List[List[int]] = []
|
||||
for i in range(num_reqs):
|
||||
if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
|
||||
# empty req or this req has no corresponding chunk, fill 0
|
||||
results.append([fill_value] * dcp_size)
|
||||
continue
|
||||
dcp_values = lst[i][chunk_idx][pcp_rank]
|
||||
results.append(dcp_values)
|
||||
return results
|
||||
|
||||
|
||||
def filter_chunked_req_indices(
|
||||
seq_len: torch.Tensor,
|
||||
mask_for_non_zero_chunk: Optional[List[bool]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
filter the reqs which are doing real chunk_prefill.
|
||||
|
||||
Args:
|
||||
seq_len: contains multi-req length: [req0_len, req1_len, ...]
|
||||
mask_for_non_zero_chunk: [True, False, True, False, ...]
|
||||
Returns:
|
||||
filtered_indices: the real chunked req's indices
|
||||
"""
|
||||
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
|
||||
mask_for_non_zero_chunk)
|
||||
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
|
||||
filtered_indices = torch.cat([
|
||||
torch.arange(offsets[i], offsets[i] + seq_len[i])
|
||||
for i in range(len(mask_for_non_zero_chunk))
|
||||
if mask_for_non_zero_chunk[i]
|
||||
])
|
||||
return filtered_indices
|
||||
|
||||
|
||||
def split_decodes_and_prefills(
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
decode_threshold: int = 1,
|
||||
|
||||
Reference in New Issue
Block a user