[feature] chunkprefill support pcp&dcp (#3801)
### What this PR does / why we need it?
ChunkPrefill now can support Long Sequence Feature Pcp&Dcp
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI tests passed with self-test
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: Apocalypse990923-qshi <qiushixu@usc.edu>
Signed-off-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <tanwenqin@huawei.com>
Co-authored-by: Delphine-Nic <3834144971@qq.com>
This commit is contained in:
@@ -83,6 +83,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
||||||
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
|
self.mock_vllm_config.scheduler_config.max_num_seqs = 10
|
||||||
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
|
self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
|
||||||
|
self.mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
self.mock_device = 'cpu:0'
|
self.mock_device = 'cpu:0'
|
||||||
self.builder = AscendAttentionMetadataBuilder(None, None,
|
self.builder = AscendAttentionMetadataBuilder(None, None,
|
||||||
self.mock_vllm_config,
|
self.mock_vllm_config,
|
||||||
|
|||||||
@@ -484,6 +484,9 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
||||||
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
|
chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
|
||||||
chunk_ctx.starts = [torch.tensor([0])]
|
chunk_ctx.starts = [torch.tensor([0])]
|
||||||
|
chunk_ctx.max_chunk_num = 1
|
||||||
|
chunk_ctx.mask_for_non_zero_chunk = [True]
|
||||||
|
chunk_ctx.local_chunked_kv_lens = [[[[8]]]]
|
||||||
|
|
||||||
prefill_meta = MagicMock()
|
prefill_meta = MagicMock()
|
||||||
prefill_meta.chunked_context = chunk_ctx
|
prefill_meta.chunked_context = chunk_ctx
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
|
extract_req_dcp_by_chunk_pcp,
|
||||||
|
filter_chunked_req_indices,
|
||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
@@ -52,6 +54,7 @@ if prefill_context_parallel_enable():
|
|||||||
get_prefill_context_model_parallel_rank,
|
get_prefill_context_model_parallel_rank,
|
||||||
get_prefill_context_model_parallel_world_size
|
get_prefill_context_model_parallel_world_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# isort: on
|
# isort: on
|
||||||
|
|
||||||
|
|
||||||
@@ -155,9 +158,23 @@ class AscendPCPMetadata:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AscendMetadataForPrefill:
|
class AscendMetadataForPrefill:
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkedContextMetadata:
|
||||||
|
actual_chunk_seq_lengths: list[int]
|
||||||
|
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||||
|
max_chunk_num: int = 0
|
||||||
|
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||||
|
Optional[list[int]]]]]]]] = None
|
||||||
|
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
kv_inverse_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||||
pcp_allgather_restore_idx: Optional[List[int]] = None
|
pcp_allgather_restore_idx: Optional[List[int]] = None
|
||||||
|
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||||
|
block_tables: torch.Tensor = None
|
||||||
|
actual_seq_lengths_q: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -165,6 +182,7 @@ class AscendMetadataForDecode:
|
|||||||
""" Decode Specific Metadata for Ascend"""
|
""" Decode Specific Metadata for Ascend"""
|
||||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||||
batch_seq_mask: torch.Tensor = None
|
batch_seq_mask: torch.Tensor = None
|
||||||
|
block_tables: torch.Tensor = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -237,13 +255,10 @@ class AscendAttentionMetadataBuilder:
|
|||||||
self.max_num_blocks_per_req = cdiv(
|
self.max_num_blocks_per_req = cdiv(
|
||||||
self.model_config.max_model_len,
|
self.model_config.max_model_len,
|
||||||
AscendAttentionBackend.get_supported_block_size()[0])
|
AscendAttentionBackend.get_supported_block_size()[0])
|
||||||
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
|
self.batch_seq_mask_buf = torch.empty(
|
||||||
'decode_max_num_seqs', 0)
|
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||||
max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
|
dtype=torch.uint8,
|
||||||
decode_max_num_seqs)
|
device=device)
|
||||||
self.batch_seq_mask_buf = torch.empty(max_num_seqs,
|
|
||||||
dtype=torch.uint8,
|
|
||||||
device=device)
|
|
||||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
||||||
) if prefill_context_parallel_enable() else 1
|
) if prefill_context_parallel_enable() else 1
|
||||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
||||||
@@ -263,6 +278,27 @@ class AscendAttentionMetadataBuilder:
|
|||||||
|
|
||||||
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||||
|
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||||
|
self.block_size - 1) // self.block_size
|
||||||
|
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||||
|
if self.chunked_prefill_enabled:
|
||||||
|
self.chunked_prefill_workspace_size = min(
|
||||||
|
# Max sure there is enough for 8 full length request or at least
|
||||||
|
# 4 pages of cache per request
|
||||||
|
max(8 * self.model_config.max_model_len,
|
||||||
|
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||||
|
128 * 1024)
|
||||||
|
assert self.chunked_prefill_workspace_size >= \
|
||||||
|
scheduler_config.max_num_seqs * self.block_size
|
||||||
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
|
(self.chunked_prefill_workspace_size,
|
||||||
|
self.model_config.get_head_size()),
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
def reorder_batch(self, input_batch,
|
def reorder_batch(self, input_batch,
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
@@ -279,9 +315,8 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_reqs
|
num_reqs
|
||||||
+ 1]
|
+ 1]
|
||||||
|
|
||||||
decode_threshold = 1
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
|
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||||
assert num_decodes + num_prefills == num_reqs
|
assert num_decodes + num_prefills == num_reqs
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||||
|
|
||||||
@@ -302,6 +337,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||||
num_reqs
|
num_reqs
|
||||||
+ 1]
|
+ 1]
|
||||||
|
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||||
|
|
||||||
if attn_state == AscendAttentionState.DecodeOnly and \
|
if attn_state == AscendAttentionState.DecodeOnly and \
|
||||||
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
common_attn_metadata.num_input_tokens > num_actual_tokens:
|
||||||
@@ -338,16 +374,35 @@ class AscendAttentionMetadataBuilder:
|
|||||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
|
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
if num_prefills > 0:
|
decode_metadata = None
|
||||||
pcp_metadata = None
|
if common_long_seq_metadata is not None:
|
||||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
chunked_context_metadata = None
|
||||||
if common_long_seq_metadata is not None:
|
if num_prefills > 0:
|
||||||
|
query_lens = query_lens[num_decode_tokens:]
|
||||||
|
context_lens_cpu = num_computed_tokens_cpu[
|
||||||
|
num_decodes:num_reqs]
|
||||||
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
|
pcp_size = get_prefill_context_model_parallel_world_size(
|
||||||
|
) if prefill_context_parallel_enable() else 1
|
||||||
|
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||||
|
cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk
|
||||||
|
kv_inverse_idx_for_chunk = torch.argsort(
|
||||||
|
cp_kv_recover_idx_for_chunk.to(torch.float32)
|
||||||
|
) if cp_kv_recover_idx_for_chunk is not None else None
|
||||||
|
chunked_context_metadata = \
|
||||||
|
AscendMetadataForPrefill.ChunkedContextMetadata(
|
||||||
|
actual_chunk_seq_lengths=torch.cumsum(query_lens * pcp_size, dim=0),
|
||||||
|
mask_for_non_zero_chunk=common_long_seq_metadata.mask_for_non_zero_chunk,
|
||||||
|
local_chunked_kv_lens=common_long_seq_metadata.local_chunked_kv_lens,
|
||||||
|
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk,
|
||||||
|
kv_inverse_idx_for_chunk=kv_inverse_idx_for_chunk,
|
||||||
|
max_chunk_num=common_long_seq_metadata.max_chunk_num
|
||||||
|
)
|
||||||
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
|
||||||
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
|
||||||
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
|
||||||
pcp_size = get_prefill_context_model_parallel_world_size(
|
|
||||||
) if prefill_context_parallel_enable() else 1
|
|
||||||
if pcp_size > 1:
|
if pcp_size > 1:
|
||||||
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
|
attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
|
||||||
dim=0).tolist()
|
dim=0).tolist()
|
||||||
@@ -355,6 +410,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
head_attn_nomask_seqlens[1], dim=0).tolist()
|
head_attn_nomask_seqlens[1], dim=0).tolist()
|
||||||
tail_attn_nomask_seqlens = torch.cumsum(
|
tail_attn_nomask_seqlens = torch.cumsum(
|
||||||
tail_attn_nomask_seqlens[1], dim=0).tolist()
|
tail_attn_nomask_seqlens[1], dim=0).tolist()
|
||||||
|
|
||||||
pcp_metadata = AscendPCPMetadata(
|
pcp_metadata = AscendPCPMetadata(
|
||||||
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
|
||||||
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
|
||||||
@@ -371,16 +427,17 @@ class AscendAttentionMetadataBuilder:
|
|||||||
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
|
||||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
|
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
|
||||||
prefill_metadata = AscendMetadataForPrefill(
|
|
||||||
pcp_metadata=pcp_metadata,
|
|
||||||
pcp_allgather_restore_idx=common_long_seq_metadata.
|
|
||||||
pcp_allgather_restore_idx
|
|
||||||
if common_long_seq_metadata is not None else None)
|
|
||||||
|
|
||||||
decode_metadata = None
|
prefill_metadata = AscendMetadataForPrefill(
|
||||||
if num_decodes > 0:
|
pcp_metadata=pcp_metadata,
|
||||||
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
|
pcp_allgather_restore_idx=common_long_seq_metadata.
|
||||||
if common_long_seq_metadata is not None:
|
pcp_allgather_restore_idx
|
||||||
|
if common_long_seq_metadata is not None else None,
|
||||||
|
chunked_context=chunked_context_metadata,
|
||||||
|
block_tables=block_table[num_decodes:],
|
||||||
|
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0))
|
||||||
|
|
||||||
|
if num_decodes > 0:
|
||||||
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
|
||||||
assert num_computed_tokens_of_pcp_dcp is not None
|
assert num_computed_tokens_of_pcp_dcp is not None
|
||||||
num_computed_tokens_array = np.array(
|
num_computed_tokens_array = np.array(
|
||||||
@@ -397,7 +454,7 @@ class AscendAttentionMetadataBuilder:
|
|||||||
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
|
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
|
||||||
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
|
batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
|
||||||
shape[0]],
|
shape[0]],
|
||||||
)
|
block_tables=block_table[:num_decodes])
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
@@ -751,7 +808,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
k_mask: torch.Tensor,
|
k_mask: torch.Tensor,
|
||||||
v_mask: torch.Tensor,
|
v_mask: torch.Tensor,
|
||||||
kv_seqlens_mask: List[int],
|
kv_seqlens_mask: List[int],
|
||||||
mask: torch.Tensor) -> torch.Tensor:
|
mask: torch.Tensor,
|
||||||
|
attn_metadata) -> torch.Tensor:
|
||||||
# nomask Attention
|
# nomask Attention
|
||||||
if k_nomask is not None:
|
if k_nomask is not None:
|
||||||
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
|
attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
@@ -786,30 +844,42 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
softmax_lse_flag=True,
|
softmax_lse_flag=True,
|
||||||
actual_seq_lengths_kv=kv_seqlens_mask,
|
actual_seq_lengths_kv=kv_seqlens_mask,
|
||||||
actual_seq_lengths=q_seqlens)
|
actual_seq_lengths=q_seqlens)
|
||||||
|
|
||||||
# update
|
# update
|
||||||
output = attn_out_mask
|
output = attn_out_mask
|
||||||
|
attn_lse = attn_lse_mask
|
||||||
if k_nomask is not None:
|
if k_nomask is not None:
|
||||||
T = attn_out_mask.shape[0]
|
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is None:
|
||||||
N = attn_out_mask.shape[1]
|
output = self._npu_attn_out_lse_update(attn_lse_mask,
|
||||||
D = attn_out_mask.shape[2]
|
attn_lse_nomask,
|
||||||
|
attn_out_mask,
|
||||||
|
attn_out_nomask)
|
||||||
|
attn_lse = None
|
||||||
|
else:
|
||||||
|
output, attn_lse = self._update_out_and_lse(
|
||||||
|
torch.stack([attn_out_nomask, attn_out_mask], dim=0),
|
||||||
|
torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
|
||||||
|
|
||||||
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
|
return output, attn_lse
|
||||||
attn_out_mask, attn_lse_mask)
|
|
||||||
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
|
|
||||||
attn_out_nomask, attn_lse_nomask)
|
|
||||||
attn_out_mask = attn_out_mask.to(torch.float32)
|
|
||||||
attn_out_nomask = attn_out_nomask.to(torch.float32)
|
|
||||||
attn_lse_mask = attn_lse_mask.to(torch.float32)
|
|
||||||
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
|
|
||||||
|
|
||||||
attn_output = [attn_out_nomask, attn_out_mask]
|
|
||||||
attn_lse = [attn_lse_nomask, attn_lse_mask]
|
|
||||||
update_type = 0
|
|
||||||
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
|
|
||||||
update_type)
|
|
||||||
output = output.view(T, N, D)
|
|
||||||
|
|
||||||
|
def _npu_attn_out_lse_update(self, attn_lse_mask, attn_lse_nomask,
|
||||||
|
attn_out_mask, attn_out_nomask):
|
||||||
|
T = attn_out_mask.shape[0]
|
||||||
|
N = attn_out_mask.shape[1]
|
||||||
|
D = attn_out_mask.shape[2]
|
||||||
|
attn_out_mask, attn_lse_mask = self._out_lse_reshape(
|
||||||
|
attn_out_mask, attn_lse_mask)
|
||||||
|
attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
|
||||||
|
attn_out_nomask, attn_lse_nomask)
|
||||||
|
attn_out_mask = attn_out_mask.to(torch.float32)
|
||||||
|
attn_out_nomask = attn_out_nomask.to(torch.float32)
|
||||||
|
attn_lse_mask = attn_lse_mask.to(torch.float32)
|
||||||
|
attn_lse_nomask = attn_lse_nomask.to(torch.float32)
|
||||||
|
attn_output = [attn_out_nomask, attn_out_mask]
|
||||||
|
attn_lse = [attn_lse_nomask, attn_lse_mask]
|
||||||
|
update_type = 0
|
||||||
|
output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
|
||||||
|
update_type)
|
||||||
|
output = output.view(T, N, D)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
|
def _forward_prefill_cp(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
@@ -831,7 +901,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||||
|
|
||||||
# 1. Attention calculation in the first half of Q in load balancing
|
# 1. Attention calculation in the first half of Q in load balancing
|
||||||
output_head = self._attention_with_nomask_and_mask(
|
output_heads, lse_heads = self._attention_with_nomask_and_mask(
|
||||||
q=torch.index_select(query, 0, q_head_idx),
|
q=torch.index_select(query, 0, q_head_idx),
|
||||||
q_seqlens=attn_mask_seqlens,
|
q_seqlens=attn_mask_seqlens,
|
||||||
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
|
k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
|
||||||
@@ -842,12 +912,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
|
k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
|
||||||
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
|
v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
|
||||||
kv_seqlens_mask=attn_mask_seqlens,
|
kv_seqlens_mask=attn_mask_seqlens,
|
||||||
mask=mask)
|
mask=mask,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
# 2. the Attention calculation in the latter half of Q in load balancing
|
# 2. the Attention calculation in the latter half of Q in load balancing
|
||||||
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
|
# pcp_rank0: Q3*KV0~KV2 + Q3*KV3
|
||||||
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
|
# pcp_rank1: Q2*KV0~KV1 + Q2*KV2
|
||||||
output_tail = self._attention_with_nomask_and_mask(
|
output_tails, lse_tails = self._attention_with_nomask_and_mask(
|
||||||
q=torch.index_select(query, 0, q_tail_idx),
|
q=torch.index_select(query, 0, q_tail_idx),
|
||||||
q_seqlens=attn_mask_seqlens,
|
q_seqlens=attn_mask_seqlens,
|
||||||
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
|
k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
|
||||||
@@ -856,13 +927,17 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
|
k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
|
||||||
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
|
v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
|
||||||
kv_seqlens_mask=attn_mask_seqlens,
|
kv_seqlens_mask=attn_mask_seqlens,
|
||||||
mask=mask)
|
mask=mask,
|
||||||
|
attn_metadata=attn_metadata)
|
||||||
|
|
||||||
# 3. Combine the output of the first half and second half.
|
|
||||||
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx
|
||||||
output = torch.index_select(
|
output = torch.index_select(
|
||||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
torch.cat([output_heads, output_tails], dim=0), 0, q_full_idx)
|
||||||
return output
|
attn_lse = None
|
||||||
|
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
||||||
|
attn_lse = torch.index_select(
|
||||||
|
torch.cat([lse_heads, lse_tails], dim=0), 0, q_full_idx)
|
||||||
|
return output, attn_lse
|
||||||
|
|
||||||
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||||
attn_lse: torch.Tensor) -> torch.Tensor:
|
attn_lse: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -928,7 +1003,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
'softmax_lse_flag':
|
'softmax_lse_flag':
|
||||||
True,
|
True,
|
||||||
'block_table':
|
'block_table':
|
||||||
attn_metadata.block_tables,
|
attn_metadata.decode_meta.block_tables,
|
||||||
'block_size':
|
'block_size':
|
||||||
self.key_cache.shape[1],
|
self.key_cache.shape[1],
|
||||||
'actual_seq_lengths_kv':
|
'actual_seq_lengths_kv':
|
||||||
@@ -1029,8 +1104,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_out = self._npu_attention_update(attn_out_lse_list)
|
attn_out = self._npu_attention_update(attn_out_lse_list)
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
|
def _update_out_and_lse(self, out_list: torch.Tensor,
|
||||||
|
lse_list: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
|
||||||
|
Args:
|
||||||
|
out_list: shape = [N, batch_size, num_heads, head_size]
|
||||||
|
lse_list: shape = [N, batch_size, num_heads, 1]
|
||||||
|
Returns:
|
||||||
|
out_final: shape = [batch_size, num_heads, head_size]
|
||||||
|
lse_final: shape = [batch_size, num_heads, 1]
|
||||||
|
"""
|
||||||
|
lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
|
||||||
|
out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
|
||||||
|
dim=0)
|
||||||
|
return out_final, lse_final
|
||||||
|
|
||||||
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
value: torch.Tensor, kv_cache: Tuple[torch.Tensor],
|
||||||
|
attn_metadata: AscendMetadata,
|
||||||
output: torch.Tensor) -> torch.Tensor:
|
output: torch.Tensor) -> torch.Tensor:
|
||||||
assert attn_metadata is not None
|
assert attn_metadata is not None
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
has_decode = attn_metadata.num_decodes > 0
|
||||||
@@ -1043,32 +1134,320 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
decode_query, attn_metadata)
|
decode_query, attn_metadata)
|
||||||
output[:num_decode_tokens] = output_decode
|
output[:num_decode_tokens] = output_decode
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
prefill_query = query[num_decode_tokens:]
|
assert attn_metadata.prefill is not None
|
||||||
|
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||||
|
prefill_query = query[
|
||||||
|
num_decode_tokens:num_actual_tokens_pcp_padded]
|
||||||
key = key[self.pcp_size * num_decode_tokens:]
|
key = key[self.pcp_size * num_decode_tokens:]
|
||||||
value = value[self.pcp_size * num_decode_tokens:]
|
value = value[self.pcp_size * num_decode_tokens:]
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
output_prefill = self._forward_prefill_cp(
|
# Scenario of Enabling PCP or PCP&DCP
|
||||||
|
attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
|
||||||
prefill_query, key, value, attn_metadata)
|
prefill_query, key, value, attn_metadata)
|
||||||
else:
|
else:
|
||||||
max_prefill_seq_len = attn_metadata.seq_lens[
|
# Scenario of Enabling DCP Individually
|
||||||
attn_metadata.num_decode_tokens:].max().item()
|
attn_output_prefill, attn_lse_prefill = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
if attn_metadata.attn_mask is not None:
|
prefill_query,
|
||||||
attn_metadata.attn_mask = attn_metadata.attn_mask[:
|
key,
|
||||||
max_prefill_seq_len, :
|
value,
|
||||||
max_prefill_seq_len]
|
num_heads=self.num_heads,
|
||||||
else:
|
num_key_value_heads=self.num_kv_heads,
|
||||||
ValueError("Attn_metadata.attn_mask is required")
|
input_layout="TND",
|
||||||
seq_lens_back = attn_metadata.seq_lens
|
atten_mask=attn_metadata.attn_mask,
|
||||||
attn_metadata.seq_lens = attn_metadata.seq_lens[
|
scale=self.scale,
|
||||||
attn_metadata.num_decode_tokens:]
|
sparse_mode=3,
|
||||||
output_prefill = self._forward_prefill_no_cache(
|
antiquant_mode=0,
|
||||||
prefill_query, key, value, attn_metadata,
|
antiquant_scale=None,
|
||||||
output[num_decode_tokens:], prefill_query.shape[0])
|
softmax_lse_flag=True,
|
||||||
attn_metadata.seq_lens = seq_lens_back
|
actual_seq_lengths_kv=attn_metadata.prefill.
|
||||||
output[num_decode_tokens:output_prefill.shape[0] +
|
actual_seq_lengths_q,
|
||||||
num_decode_tokens] = output_prefill
|
actual_seq_lengths=attn_metadata.prefill.
|
||||||
|
actual_seq_lengths_q)
|
||||||
|
|
||||||
|
self._process_chunk_prefill(attn_output_prefill, attn_lse_prefill,
|
||||||
|
kv_cache, prefill_query, attn_metadata)
|
||||||
|
output[num_decode_tokens:attn_output_prefill.shape[0] +
|
||||||
|
num_decode_tokens] = attn_output_prefill
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _process_chunk_prefill(self, current_attn_output_prefill,
|
||||||
|
current_attn_lse_prefill, kv_cache,
|
||||||
|
prefill_query, attn_metadata):
|
||||||
|
if attn_metadata.prefill is not None and attn_metadata.prefill.chunked_context is not None:
|
||||||
|
prefill_query_all = self._prefill_query_all_gather(
|
||||||
|
attn_metadata, prefill_query)
|
||||||
|
attn_output_full_chunk, attn_lse_full_chunk = self._compute_prefill_context(
|
||||||
|
prefill_query_all, kv_cache, attn_metadata)
|
||||||
|
self._update_chunk_attn_out_lse_with_current_attn_out_lse(
|
||||||
|
current_attn_output_prefill, current_attn_lse_prefill,
|
||||||
|
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
|
||||||
|
attn_metadata)
|
||||||
|
|
||||||
|
def _update_chunk_attn_out_lse_with_current_attn_out_lse(
|
||||||
|
self, current_attn_output_prefill, current_attn_lse_prefill,
|
||||||
|
attn_output_full_chunk, attn_lse_full_chunk, prefill_query,
|
||||||
|
attn_metadata):
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
inverse_idx = attn_metadata.prefill.chunked_context.kv_inverse_idx_for_chunk
|
||||||
|
attn_output_full_chunk = torch.index_select(
|
||||||
|
attn_output_full_chunk, 0, inverse_idx)
|
||||||
|
attn_lse_full_chunk = torch.index_select(attn_lse_full_chunk, 0,
|
||||||
|
inverse_idx)
|
||||||
|
num_tokens = prefill_query.size(0)
|
||||||
|
attn_output_full_chunk = attn_output_full_chunk[
|
||||||
|
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||||
|
attn_lse_full_chunk = attn_lse_full_chunk[
|
||||||
|
self.pcp_rank * num_tokens:(self.pcp_rank + 1) * num_tokens, :, :]
|
||||||
|
assert attn_output_full_chunk.shape == current_attn_output_prefill.shape and attn_lse_full_chunk.shape == current_attn_lse_prefill.shape
|
||||||
|
seq_len = attn_metadata.query_lens.detach().clone()
|
||||||
|
filtered_indices = filter_chunked_req_indices(
|
||||||
|
seq_len,
|
||||||
|
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||||
|
|
||||||
|
attn_output_prefill_filtered = current_attn_output_prefill[
|
||||||
|
filtered_indices, :, :]
|
||||||
|
attn_lse_prefill_filtered = current_attn_lse_prefill[
|
||||||
|
filtered_indices, :, :]
|
||||||
|
attn_output_full_chunk = attn_output_full_chunk[filtered_indices, :, :]
|
||||||
|
attn_lse_full_chunk = attn_lse_full_chunk[filtered_indices, :, :]
|
||||||
|
|
||||||
|
attn_output_filtered = self._npu_attn_out_lse_update(
|
||||||
|
attn_lse_prefill_filtered, attn_lse_full_chunk,
|
||||||
|
attn_output_prefill_filtered, attn_output_full_chunk)
|
||||||
|
current_attn_output_prefill[
|
||||||
|
filtered_indices, :, :] = attn_output_filtered.to(
|
||||||
|
current_attn_output_prefill.dtype)
|
||||||
|
|
||||||
|
def _prefill_query_all_gather(self, attn_metadata, prefill_query):
|
||||||
|
prefill_query_all = get_pcp_group().all_gather(prefill_query.contiguous(),
|
||||||
|
0) \
|
||||||
|
if self.pcp_size > 1 else prefill_query
|
||||||
|
prefill_query_all = torch.index_select(prefill_query_all,
|
||||||
|
0,
|
||||||
|
attn_metadata.prefill.chunked_context.cp_kv_recover_idx_for_chunk) \
|
||||||
|
if self.pcp_size > 1 else prefill_query_all
|
||||||
|
return prefill_query_all
|
||||||
|
|
||||||
|
def _compute_prefill_context(self, query: torch.Tensor,
|
||||||
|
kv_cache: Tuple[torch.Tensor],
|
||||||
|
attn_metadata: AscendMetadata):
|
||||||
|
assert len(kv_cache) > 1
|
||||||
|
assert attn_metadata is not None
|
||||||
|
assert attn_metadata.prefill is not None
|
||||||
|
assert attn_metadata.prefill.chunked_context is not None
|
||||||
|
prefill_metadata = attn_metadata.prefill
|
||||||
|
local_chunked_kv_lens = attn_metadata.prefill.chunked_context.local_chunked_kv_lens
|
||||||
|
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||||
|
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||||
|
|
||||||
|
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||||
|
|
||||||
|
iters = max_chunk_num
|
||||||
|
# Keep the causal mask; do not override to all-ones. [req_id][chunk_id][cp-rank][dcp_rank]
|
||||||
|
context_starts_rank = None
|
||||||
|
|
||||||
|
prefix_output_list = []
|
||||||
|
prefix_lse_list = []
|
||||||
|
for i in range(iters):
|
||||||
|
key, value, seq_lens_current_chunk_rank = self._load_kv_for_chunk(
|
||||||
|
attn_metadata, kv_cache, context_starts_rank, i,
|
||||||
|
local_chunked_kv_lens, prefill_metadata, query)
|
||||||
|
|
||||||
|
# 2. Attention computation
|
||||||
|
if seq_lens_current_chunk_rank is None or torch.all(
|
||||||
|
seq_lens_current_chunk_rank == 0).item():
|
||||||
|
prefix_output = torch.full(
|
||||||
|
(query.size(0), self.num_heads, self.head_size),
|
||||||
|
fill_value=0,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
prefix_lse = torch.full((query.size(0), self.num_heads, 1),
|
||||||
|
fill_value=0,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=query.device)
|
||||||
|
else:
|
||||||
|
|
||||||
|
actual_seq_lengths_kv = torch.cumsum(
|
||||||
|
seq_lens_current_chunk_rank, dim=0).tolist()
|
||||||
|
prefix_output, prefix_lse = torch.ops.npu.npu_fused_infer_attention_score(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_key_value_heads=self.num_kv_heads,
|
||||||
|
input_layout="TND", #
|
||||||
|
atten_mask=None,
|
||||||
|
scale=self.scale,
|
||||||
|
sparse_mode=0,
|
||||||
|
antiquant_mode=0,
|
||||||
|
antiquant_scale=None,
|
||||||
|
softmax_lse_flag=True,
|
||||||
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
|
actual_seq_lengths=attn_metadata.prefill.chunked_context.
|
||||||
|
actual_chunk_seq_lengths)
|
||||||
|
prefix_output_list.append(prefix_output)
|
||||||
|
prefix_lse_list.append(prefix_lse)
|
||||||
|
|
||||||
|
# 3. update attn-out & lse
|
||||||
|
prefix_output, prefix_lse = self._update_attn_out_lse_in_chunks(
|
||||||
|
prefix_output_list, prefix_lse_list)
|
||||||
|
|
||||||
|
self._update_attn_out_lse_in_pcp(attn_metadata, prefix_output,
|
||||||
|
prefix_lse)
|
||||||
|
|
||||||
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
|
def _update_attn_out_lse_in_chunks(self, prefix_output_list,
|
||||||
|
prefix_lse_list):
|
||||||
|
# update output and lse
|
||||||
|
if len(prefix_output_list) > 1:
|
||||||
|
prefix_output, prefix_lse = self._update_out_and_lse(
|
||||||
|
torch.stack(prefix_output_list, dim=0),
|
||||||
|
torch.stack(prefix_lse_list, dim=0))
|
||||||
|
else:
|
||||||
|
prefix_output = prefix_output_list[0]
|
||||||
|
prefix_lse = prefix_lse_list[0]
|
||||||
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
|
def _update_attn_out_lse_in_pcp(self, attn_metadata, prefix_output,
|
||||||
|
prefix_lse):
|
||||||
|
# CP dimension all_gather and fusion
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
# filter non-zero chunk part of prefix_output
|
||||||
|
current_seq_lens = attn_metadata.query_lens.detach().clone()
|
||||||
|
current_seq_lens.mul_(self.pcp_size) # q_full
|
||||||
|
current_seq_lens_cpu = current_seq_lens.cpu()
|
||||||
|
filtered_indices = filter_chunked_req_indices(
|
||||||
|
current_seq_lens_cpu,
|
||||||
|
attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk)
|
||||||
|
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||||
|
prefix_lse_filtered = prefix_lse[filtered_indices, :, :]
|
||||||
|
|
||||||
|
out_lse_local = torch.cat(
|
||||||
|
[prefix_output_filtered, prefix_lse_filtered], dim=-1)
|
||||||
|
attn_out_lse_list = [
|
||||||
|
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||||
|
]
|
||||||
|
dist.all_gather(attn_out_lse_list,
|
||||||
|
out_lse_local,
|
||||||
|
group=self.pcp_group)
|
||||||
|
|
||||||
|
attn_out_lse_allgather = torch.stack(
|
||||||
|
attn_out_lse_list,
|
||||||
|
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
||||||
|
attn_out_allgather, attn_lse_allgather = torch.split(
|
||||||
|
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
||||||
|
prefix_output_filtered, prefix_lse_filtered = self._update_out_and_lse(
|
||||||
|
attn_out_allgather, attn_lse_allgather)
|
||||||
|
|
||||||
|
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||||
|
prefix_output.dtype)
|
||||||
|
prefix_lse[filtered_indices, :, :] = prefix_lse_filtered.to(
|
||||||
|
prefix_lse.dtype)
|
||||||
|
|
||||||
|
def _load_kv_for_chunk(self, attn_metadata, kv_cache, context_starts_rank,
|
||||||
|
i, local_chunked_kv_lens, prefill_metadata, query):
|
||||||
|
|
||||||
|
cache_key = kv_cache[0]
|
||||||
|
cache_value = kv_cache[1]
|
||||||
|
num_heads = cache_key.size(2)
|
||||||
|
head_size = kv_cache[0].size(-1)
|
||||||
|
|
||||||
|
# 1. Load current query's history key-value
|
||||||
|
seq_lens_current_chunk = attn_metadata.query_lens.detach().clone()
|
||||||
|
num_requests = len(seq_lens_current_chunk)
|
||||||
|
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||||
|
context_starts_rank = torch.zeros(
|
||||||
|
num_requests, dtype=torch.int32, device=query.device
|
||||||
|
) if context_starts_rank is None else context_starts_rank
|
||||||
|
# Calculate tokens each rank should process per request
|
||||||
|
seq_lens_current_chunk_rank = torch.zeros_like(seq_lens_current_chunk,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=query.device)
|
||||||
|
total_toks = 0
|
||||||
|
for req_idx in range(num_requests):
|
||||||
|
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||||
|
continue
|
||||||
|
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||||
|
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||||
|
seq_lens_current_chunk_rank[req_idx] = n_computed_acc[
|
||||||
|
self.pcp_rank][self.dcp_rank]
|
||||||
|
if total_toks > 0:
|
||||||
|
key = torch.empty(total_toks,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
value = torch.empty(total_toks,
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
|
||||||
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
|
cache_key,
|
||||||
|
cache_value,
|
||||||
|
attn_metadata.prefill.block_tables,
|
||||||
|
seq_lens_current_chunk_rank.to(query.device),
|
||||||
|
seq_starts=
|
||||||
|
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If current rank has no tokens to process, create empty tensors
|
||||||
|
key = torch.empty(0,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
value = torch.empty(0,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_size,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
seq_lens_current_chunk_rank = torch.zeros(
|
||||||
|
(len(seq_lens_current_chunk), ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=query.device)
|
||||||
|
for req_idx in range(num_requests):
|
||||||
|
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||||
|
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||||
|
continue
|
||||||
|
context_starts_rank[req_idx] += local_chunked_kv_lens[req_idx][i][
|
||||||
|
self.pcp_rank][self.dcp_rank]
|
||||||
|
if self.dcp_size > 1:
|
||||||
|
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||||
|
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank)
|
||||||
|
|
||||||
|
assert len(req_dcp_sizes) == num_requests and all(
|
||||||
|
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||||
|
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||||
|
kv_local = torch.cat([key, value], dim=-1)
|
||||||
|
head_dim = kv_local.size(-1)
|
||||||
|
kv_full = torch.empty((total_toks, num_heads, head_dim),
|
||||||
|
device=query.device,
|
||||||
|
dtype=query.dtype)
|
||||||
|
|
||||||
|
kv_full_list = [None for _ in range(self.dcp_size)]
|
||||||
|
dist.all_gather_object(kv_full_list,
|
||||||
|
kv_local,
|
||||||
|
group=self.dcp_group)
|
||||||
|
kv_full_list = [
|
||||||
|
kv for kv in kv_full_list if kv is not None and kv.numel() > 0
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(kv_full_list) > 0:
|
||||||
|
kv_full = torch.cat(kv_full_list, dim=0)
|
||||||
|
key, value = kv_full.split([head_size, head_size], dim=-1)
|
||||||
|
if total_toks == 0:
|
||||||
|
return key, value, None
|
||||||
|
seq_lens_current_chunk_rank = torch.tensor(
|
||||||
|
np.sum(np.array(req_dcp_sizes), axis=1),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=query.device) # [reqs]
|
||||||
|
return key, value, seq_lens_current_chunk_rank
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
@@ -1162,7 +1541,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
intermediate_output = self._forward_pcp_dcp(
|
intermediate_output = self._forward_pcp_dcp(
|
||||||
query, key, value, attn_metadata, output)
|
query, key, value, kv_cache, attn_metadata, output)
|
||||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||||
@@ -1185,7 +1564,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
intermediate_output = self._forward_prefill_no_cache(
|
intermediate_output = self._forward_prefill_no_cache(
|
||||||
query, key, value, attn_metadata, output, num_tokens)
|
query, key, value, attn_metadata, output, num_tokens)
|
||||||
elif attn_metadata.attn_state == \
|
elif attn_metadata.attn_state == \
|
||||||
AscendAttentionState.PrefillCacheHit:
|
AscendAttentionState.PrefillCacheHit:
|
||||||
intermediate_output = self._forward_prefill_cache_hit(
|
intermediate_output = self._forward_prefill_cache_hit(
|
||||||
query, attn_metadata, output)
|
query, attn_metadata, output)
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
@@ -27,11 +28,14 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
|||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|
||||||
maybe_save_kv_layer_to_connector,
|
# isort: off
|
||||||
split_decodes_and_prefills,
|
from vllm_ascend.attention.utils import (
|
||||||
trans_rope_weight, transdata,
|
AscendCommonAttentionMetadata, extract_req_dcp_by_chunk_pcp,
|
||||||
wait_for_kv_layer_from_connector)
|
filter_chunked_req_indices, maybe_save_kv_layer_to_connector,
|
||||||
|
split_decodes_and_prefills, trans_rope_weight, transdata,
|
||||||
|
wait_for_kv_layer_from_connector)
|
||||||
|
# isort: on
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
update_graph_params_workspaces)
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||||
@@ -111,6 +115,10 @@ class AscendMLAPrefillMetadata:
|
|||||||
workspace: torch.Tensor
|
workspace: torch.Tensor
|
||||||
chunk_seq_lens: torch.Tensor
|
chunk_seq_lens: torch.Tensor
|
||||||
chunk_seq_lens_npu: torch.Tensor
|
chunk_seq_lens_npu: torch.Tensor
|
||||||
|
mask_for_non_zero_chunk: Optional[list[bool]] = None
|
||||||
|
max_chunk_num: int = 0
|
||||||
|
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||||
|
Optional[list[int]]]]]]]] = None
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: torch.Tensor
|
query_lens: torch.Tensor
|
||||||
@@ -125,6 +133,7 @@ class AscendMLAPrefillMetadata:
|
|||||||
sin: torch.Tensor = None
|
sin: torch.Tensor = None
|
||||||
cos: torch.Tensor = None
|
cos: torch.Tensor = None
|
||||||
pcp_metadata: Optional[AscendPCPMetadata] = None
|
pcp_metadata: Optional[AscendPCPMetadata] = None
|
||||||
|
cp_kv_recover_idx_for_chunk: Optional[list[int]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -347,6 +356,10 @@ class AscendMLAMetadataBuilder:
|
|||||||
|
|
||||||
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
num_actual_tokens_pcp_padded = long_seq_metadata.num_actual_tokens_pcp_padded if long_seq_metadata else None
|
||||||
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
|
num_computed_tokens_of_pcp_dcp = long_seq_metadata.num_computed_tokens_of_pcp_dcp if long_seq_metadata else None
|
||||||
|
cp_kv_recover_idx_for_chunk = long_seq_metadata.cp_kv_recover_idx_for_chunk if long_seq_metadata else None
|
||||||
|
local_chunked_kv_lens = long_seq_metadata.local_chunked_kv_lens if long_seq_metadata else None
|
||||||
|
mask_for_non_zero_chunk = long_seq_metadata.mask_for_non_zero_chunk if long_seq_metadata else None
|
||||||
|
max_chunk_num = long_seq_metadata.max_chunk_num if long_seq_metadata else 0
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||||
@@ -359,14 +372,15 @@ class AscendMLAMetadataBuilder:
|
|||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||||
input_positions = common_attn_metadata.positions[:
|
|
||||||
num_actual_tokens].long(
|
|
||||||
)
|
|
||||||
if num_actual_tokens_pcp_padded is None:
|
if num_actual_tokens_pcp_padded is None:
|
||||||
num_actual_tokens_pcp_padded = num_actual_tokens
|
num_actual_tokens_pcp_padded = num_actual_tokens
|
||||||
|
|
||||||
slot_mapping = common_attn_metadata.slot_mapping[:
|
slot_mapping = common_attn_metadata.slot_mapping[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
|
input_positions = common_attn_metadata.positions[:
|
||||||
|
num_actual_tokens_pcp_padded].long(
|
||||||
|
)
|
||||||
|
|
||||||
if self.cos_cache is None:
|
if self.cos_cache is None:
|
||||||
self.cos_cache = model.model.layers[
|
self.cos_cache = model.model.layers[
|
||||||
@@ -408,7 +422,8 @@ class AscendMLAMetadataBuilder:
|
|||||||
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
tail_attn_nomask_seqlens=common_long_seq_metadata.
|
||||||
tail_attn_nomask_seqlens,
|
tail_attn_nomask_seqlens,
|
||||||
q_full_idx=common_long_seq_metadata.q_full_idx,
|
q_full_idx=common_long_seq_metadata.q_full_idx,
|
||||||
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask,
|
pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask
|
||||||
|
if long_seq_metadata else None,
|
||||||
pcp_allgather_restore_idx=long_seq_metadata.
|
pcp_allgather_restore_idx=long_seq_metadata.
|
||||||
pcp_allgather_restore_idx if long_seq_metadata else None)
|
pcp_allgather_restore_idx if long_seq_metadata else None)
|
||||||
|
|
||||||
@@ -452,6 +467,9 @@ class AscendMLAMetadataBuilder:
|
|||||||
chunk_seq_lens=chunk_seq_lens,
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
chunk_seq_lens_npu=chunk_seq_lens.npu(),
|
||||||
workspace=self.chunked_prefill_workspace,
|
workspace=self.chunked_prefill_workspace,
|
||||||
|
local_chunked_kv_lens=local_chunked_kv_lens,
|
||||||
|
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
|
||||||
|
max_chunk_num=max_chunk_num,
|
||||||
)
|
)
|
||||||
prefill_input_positions = input_positions[tokens_start:]
|
prefill_input_positions = input_positions[tokens_start:]
|
||||||
cos = self.cos_cache[
|
cos = self.cos_cache[
|
||||||
@@ -474,7 +492,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
sin=sin,
|
sin=sin,
|
||||||
cos=cos,
|
cos=cos,
|
||||||
pcp_metadata=pcp_metadata,
|
pcp_metadata=pcp_metadata,
|
||||||
)
|
cp_kv_recover_idx_for_chunk=cp_kv_recover_idx_for_chunk)
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
@@ -887,8 +905,26 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||||
return prefix_output, prefix_lse
|
return prefix_output, prefix_lse
|
||||||
|
local_chunked_kv_lens = prefill_metadata.chunked_context.local_chunked_kv_lens
|
||||||
|
mask_for_non_zero_chunk = prefill_metadata.chunked_context.mask_for_non_zero_chunk
|
||||||
|
max_chunk_num = prefill_metadata.chunked_context.max_chunk_num
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
assert local_chunked_kv_lens is not None and mask_for_non_zero_chunk is not None and max_chunk_num > 0
|
||||||
|
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
prefix_output = torch.zeros(q_nope.shape[0],
|
||||||
|
self.num_heads,
|
||||||
|
self.v_head_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
prefix_lse = torch.zeros(self.num_heads,
|
||||||
|
q_pe.shape[0],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=q_pe.device)
|
||||||
|
|
||||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
iters = max_chunk_num
|
||||||
|
|
||||||
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
current_seq_len = torch.tensor(prefill_metadata.query_lens,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
@@ -896,60 +932,305 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
cache_k_pe = kv_c_and_k_pe_cache[1]
|
cache_k_pe = kv_c_and_k_pe_cache[1]
|
||||||
num_heads = cache_k_pe.size(2)
|
num_heads = cache_k_pe.size(2)
|
||||||
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1)
|
||||||
|
# token -> request mapping for building per-token masks when CP>1
|
||||||
|
seq_len1 = torch.tensor(prefill_metadata.query_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_nope.device)
|
||||||
|
seq_len1.mul_(
|
||||||
|
self.pcp_size) # q_full: already padded, divisible by cp_size
|
||||||
|
|
||||||
|
# Select mask: prefer CP prefill mask from metadata; fallback to cached prefill_mask; create if needed.
|
||||||
|
mask_local = None
|
||||||
|
if attn_metadata is not None and attn_metadata.prefill is not None and \
|
||||||
|
attn_metadata.prefill.pcp_metadata is not None and attn_metadata.prefill.pcp_metadata.pcp_prefill_mask is not None:
|
||||||
|
mask_local = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||||
|
else:
|
||||||
|
mask_local = self.prefill_mask
|
||||||
|
if mask_local is None:
|
||||||
|
mask_local = torch.triu(
|
||||||
|
torch.ones(512,
|
||||||
|
512,
|
||||||
|
device=q_nope.device,
|
||||||
|
dtype=q_nope.dtype), 1)
|
||||||
|
self.prefill_mask = mask_local
|
||||||
|
|
||||||
|
# Keep the causal mask; do not override to all-ones.
|
||||||
|
context_starts_rank = None
|
||||||
|
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
## DCP mode: each rank processes its own (cp,dcp) historical context slice per request dimension
|
||||||
|
num_requests = len(seq_len1)
|
||||||
|
assert num_requests == len(local_chunked_kv_lens)
|
||||||
|
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||||
|
context_starts_rank = torch.zeros(
|
||||||
|
num_requests, dtype=torch.int32, device=q_nope.device
|
||||||
|
) if context_starts_rank is None else context_starts_rank
|
||||||
|
|
||||||
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
## Calculate tokens each rank should process per request
|
||||||
i]
|
seq_len2_rank = torch.zeros_like(seq_len1, dtype=torch.int32)
|
||||||
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
total_toks = 0
|
||||||
i]
|
|
||||||
seq_len = torch.stack([current_seq_len, context_seq_len])
|
|
||||||
kv_c_normed = torch.empty(toks,
|
|
||||||
num_heads,
|
|
||||||
latent_kv_dim,
|
|
||||||
dtype=q_nope.dtype,
|
|
||||||
device=q_nope.device)
|
|
||||||
k_pe = torch.empty(toks,
|
|
||||||
num_heads,
|
|
||||||
rope_dim,
|
|
||||||
dtype=q_nope.dtype,
|
|
||||||
device=q_nope.device)
|
|
||||||
|
|
||||||
torch_npu.atb.npu_paged_cache_load(
|
for req_idx in range(num_requests):
|
||||||
cache_kv_c,
|
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||||
cache_k_pe,
|
continue
|
||||||
prefill_metadata.block_table,
|
n_computed_acc = local_chunked_kv_lens[req_idx][i]
|
||||||
context_seq_len_npu,
|
total_toks += n_computed_acc[self.pcp_rank][self.dcp_rank]
|
||||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
seq_len2_rank[req_idx] = n_computed_acc[self.pcp_rank][
|
||||||
key=kv_c_normed,
|
self.dcp_rank]
|
||||||
value=k_pe,
|
|
||||||
)
|
if total_toks > 0:
|
||||||
|
kv_c_normed = torch.empty(total_toks,
|
||||||
|
num_heads,
|
||||||
|
latent_kv_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
k_pe = torch.empty(total_toks,
|
||||||
|
num_heads,
|
||||||
|
rope_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
|
||||||
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
|
cache_kv_c,
|
||||||
|
cache_k_pe,
|
||||||
|
prefill_metadata.block_table,
|
||||||
|
seq_len2_rank.to(q_nope.device),
|
||||||
|
seq_starts=
|
||||||
|
context_starts_rank, # slot offsets of current chunk in current iteration
|
||||||
|
key=kv_c_normed,
|
||||||
|
value=k_pe,
|
||||||
|
)
|
||||||
|
seq_len2 = seq_len2_rank.to(q_nope.device)
|
||||||
|
else:
|
||||||
|
# If current rank has no tokens to process, create empty tensors
|
||||||
|
kv_c_normed = torch.empty(0,
|
||||||
|
num_heads,
|
||||||
|
latent_kv_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
k_pe = torch.empty(0,
|
||||||
|
num_heads,
|
||||||
|
rope_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
seq_len2 = torch.zeros((len(seq_len1), ),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_nope.device)
|
||||||
|
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||||
|
for req_idx in range(num_requests):
|
||||||
|
# Before dealing with a new chunk, set to zero, and accumulate the start positions as chunk prefill step increases
|
||||||
|
if i >= len(local_chunked_kv_lens[req_idx]):
|
||||||
|
continue
|
||||||
|
context_starts_rank[req_idx] += local_chunked_kv_lens[
|
||||||
|
req_idx][i][self.pcp_rank][self.dcp_rank]
|
||||||
|
else:
|
||||||
|
# Original logic: ChunkPrefill-only mode
|
||||||
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
|
||||||
|
i]
|
||||||
|
context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
|
||||||
|
i]
|
||||||
|
seq_len = torch.stack([current_seq_len, context_seq_len])
|
||||||
|
|
||||||
|
kv_c_normed = torch.empty(toks,
|
||||||
|
num_heads,
|
||||||
|
latent_kv_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
k_pe = torch.empty(toks,
|
||||||
|
num_heads,
|
||||||
|
rope_dim,
|
||||||
|
dtype=q_nope.dtype,
|
||||||
|
device=q_nope.device)
|
||||||
|
|
||||||
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
|
cache_kv_c,
|
||||||
|
cache_k_pe,
|
||||||
|
prefill_metadata.block_table,
|
||||||
|
context_seq_len_npu,
|
||||||
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
|
key=kv_c_normed,
|
||||||
|
value=k_pe,
|
||||||
|
)
|
||||||
|
|
||||||
kv_c_normed = kv_c_normed.squeeze()
|
kv_c_normed = kv_c_normed.squeeze()
|
||||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
if self.dcp_size > 1:
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
# DCP mode: first all_gather within DCP group, let each rank in CP group share complete sequence blocks
|
||||||
k_nope, v = kv_nope\
|
# Step 1: DCP all_gather latent
|
||||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
kv_c_k_pe_local = torch.cat(
|
||||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
[kv_c_normed, k_pe.squeeze()],
|
||||||
torch_npu.atb.npu_ring_mla(
|
dim=-1) # [local_toks, latent_dim + rope_dim]
|
||||||
q_nope=q_nope,
|
|
||||||
q_rope=q_pe,
|
# Step 2: use all_gather_into_tensor_uneven (gather + cat)
|
||||||
k_nope=k_nope,
|
req_dcp_sizes = extract_req_dcp_by_chunk_pcp(
|
||||||
k_rope=k_pe,
|
local_chunked_kv_lens, i, self.dcp_size, self.pcp_rank
|
||||||
value=v,
|
) # need to know num tokens of each rank in dcp group before all_gather # [reqs, dcp]
|
||||||
mask=self.prefill_mask,
|
assert len(req_dcp_sizes) == num_requests and all(
|
||||||
seqlen=seq_len,
|
len(dcp_arr) == self.dcp_size for dcp_arr in req_dcp_sizes)
|
||||||
head_num=self.num_heads,
|
total_toks = np.sum(np.array(req_dcp_sizes))
|
||||||
kv_head_num=self.num_heads,
|
latent_rope_dim = kv_c_k_pe_local.size(-1)
|
||||||
pre_out=prefix_output,
|
kv_c_k_pe_full = torch.empty((total_toks, latent_rope_dim),
|
||||||
prev_lse=prefix_lse,
|
device=kv_c_k_pe_local.device,
|
||||||
qk_scale=self.scale,
|
dtype=kv_c_k_pe_local.dtype)
|
||||||
kernel_type="kernel_type_high_precision",
|
|
||||||
mask_type="no_mask",
|
kv_c_k_pe_full_list = [None for _ in range(self.dcp_size)]
|
||||||
input_layout="type_bsnd",
|
dist.all_gather_object(kv_c_k_pe_full_list,
|
||||||
calc_type="calc_type_default",
|
kv_c_k_pe_local,
|
||||||
output=prefix_output,
|
group=self.dcp_group)
|
||||||
softmax_lse=prefix_lse)
|
kv_c_k_pe_full_list = [
|
||||||
|
kv_c_k_pe for kv_c_k_pe in kv_c_k_pe_full_list
|
||||||
|
if kv_c_k_pe is not None and kv_c_k_pe.numel() > 0
|
||||||
|
]
|
||||||
|
if len(kv_c_k_pe_full_list) > 0:
|
||||||
|
kv_c_k_pe_full = torch.cat(kv_c_k_pe_full_list, dim=0)
|
||||||
|
if len(kv_c_k_pe_full.shape) == 1:
|
||||||
|
assert total_toks == 1
|
||||||
|
kv_c_k_pe_full = kv_c_k_pe_full.unsqueeze(0)
|
||||||
|
assert kv_c_k_pe_full.shape[
|
||||||
|
0] == total_toks and kv_c_k_pe_full.shape[
|
||||||
|
1] == latent_rope_dim
|
||||||
|
kv_c_normed_full, k_pe_full = torch.split(
|
||||||
|
kv_c_k_pe_full, [latent_kv_dim, rope_dim], dim=-1)
|
||||||
|
|
||||||
|
# Step 3: process complete sequence with TP projection to get current rank's head slice
|
||||||
|
# Case that no kv_cache has been stored on this CP rank(after dcp all_gather), no need to do following computation.
|
||||||
|
if total_toks == 0:
|
||||||
|
continue
|
||||||
|
kv_nope = self.kv_b_proj(kv_c_normed_full)[0].view(
|
||||||
|
-1, self.num_heads,
|
||||||
|
self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
k_nope, v = kv_nope.split(
|
||||||
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
k_pe = k_pe_full.unsqueeze(1).expand((*k_nope.shape[:-1], -1))
|
||||||
|
|
||||||
|
seq_len2 = torch.tensor(np.sum(np.array(req_dcp_sizes),
|
||||||
|
axis=1),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_nope.device) # [reqs]
|
||||||
|
seq_len = torch.stack([seq_len1.cpu(), seq_len2.cpu()])
|
||||||
|
else:
|
||||||
|
# Non-DCP mode: use TP-split projection
|
||||||
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
|
||||||
|
-1, self.num_heads,
|
||||||
|
self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
k_nope, v = kv_nope.split(
|
||||||
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
|
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
# Case that no kv_cache has been stored on this CP rank, no need to do following computation.
|
||||||
|
if torch.all(seq_len2 == 0).item():
|
||||||
|
continue
|
||||||
|
# PCP mode: first compute this rank's contribution to the chunk
|
||||||
|
if i == 0:
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=v,
|
||||||
|
mask=mask_local,
|
||||||
|
seqlen=seq_len,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=None,
|
||||||
|
prev_lse=None,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="no_mask",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_first_ring",
|
||||||
|
output=prefix_output,
|
||||||
|
softmax_lse=prefix_lse)
|
||||||
|
continue
|
||||||
|
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=v,
|
||||||
|
mask=mask_local,
|
||||||
|
seqlen=seq_len,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=prefix_output,
|
||||||
|
prev_lse=prefix_lse,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="no_mask",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_default",
|
||||||
|
output=prefix_output,
|
||||||
|
softmax_lse=prefix_lse)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert not torch.all(context_seq_len == 0).item()
|
||||||
|
# compute this chunk block then update prefix tensors to keep shapes consistent
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=v,
|
||||||
|
mask=mask_local,
|
||||||
|
seqlen=seq_len,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=prefix_output,
|
||||||
|
prev_lse=prefix_lse,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="no_mask",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_default",
|
||||||
|
output=prefix_output,
|
||||||
|
softmax_lse=prefix_lse)
|
||||||
|
|
||||||
|
# CP dimension all_gather and fusion
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
# filter non-zero chunk part of prefix_output
|
||||||
|
seq_len1_cpu = seq_len1.cpu()
|
||||||
|
filtered_indices = filter_chunked_req_indices(
|
||||||
|
seq_len1_cpu, mask_for_non_zero_chunk)
|
||||||
|
prefix_output_filtered = prefix_output[filtered_indices, :, :]
|
||||||
|
prefix_lse_filtered = prefix_lse[:, filtered_indices]
|
||||||
|
|
||||||
|
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||||
|
prefix_lse_filtered_bt = prefix_lse_filtered.permute(
|
||||||
|
1, 0).unsqueeze(-1).contiguous(
|
||||||
|
) if prefix_lse_filtered is not None else None
|
||||||
|
out_lse_local = torch.cat(
|
||||||
|
[prefix_output_filtered, prefix_lse_filtered_bt], dim=-1)
|
||||||
|
out_lse_list = [
|
||||||
|
torch.empty_like(out_lse_local) for _ in range(self.pcp_size)
|
||||||
|
]
|
||||||
|
dist.all_gather(out_lse_list, out_lse_local, group=self.pcp_group)
|
||||||
|
prefix_output_filtered = None
|
||||||
|
prefix_lse_filtered_bt = None
|
||||||
|
for r in range(self.pcp_size):
|
||||||
|
out_lse_r = out_lse_list[r]
|
||||||
|
if torch.all(out_lse_r == 0).item():
|
||||||
|
continue
|
||||||
|
out_r, lse_r = torch.split(out_lse_r, [self.v_head_dim, 1],
|
||||||
|
dim=-1)
|
||||||
|
token_mask = torch.ones([out_r.size(0)],
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=out_r.device)
|
||||||
|
prefix_output_filtered, prefix_lse_filtered_bt = self._update_out_and_lse(
|
||||||
|
prefix_output_filtered, prefix_lse_filtered_bt, out_r,
|
||||||
|
lse_r, token_mask)
|
||||||
|
# convert lse back to [heads, bs]
|
||||||
|
assert prefix_output_filtered is not None and prefix_lse_filtered_bt is not None
|
||||||
|
prefix_lse_filtered = prefix_lse_filtered_bt.squeeze(-1).permute(
|
||||||
|
1, 0).contiguous()
|
||||||
|
|
||||||
|
prefix_output[filtered_indices, :, :] = prefix_output_filtered.to(
|
||||||
|
prefix_output.dtype)
|
||||||
|
prefix_lse[:, filtered_indices] = prefix_lse_filtered.to(
|
||||||
|
prefix_lse.dtype)
|
||||||
|
|
||||||
return prefix_output, prefix_lse
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
def _forward_prefill(
|
def _forward_prefill(
|
||||||
@@ -1516,7 +1797,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens
|
||||||
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
|
||||||
|
|
||||||
output_head = self._attention_with_mask_and_nomask(
|
output_head, head_lse = self._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||||
k_nope=k_nope,
|
k_nope=k_nope,
|
||||||
@@ -1528,7 +1809,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||||
mask=mask)
|
mask=mask)
|
||||||
|
|
||||||
output_tail = self._attention_with_mask_and_nomask(
|
output_tail, tail_lse = self._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||||
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
||||||
k_nope=k_nope,
|
k_nope=k_nope,
|
||||||
@@ -1544,7 +1825,83 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
output = torch.index_select(
|
output = torch.index_select(
|
||||||
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
|
||||||
|
|
||||||
output = output.reshape([num_tokens, self.num_heads * self.v_head_dim])
|
# Synchronize and reorder LSE for subsequent chunked context accumulation
|
||||||
|
attn_lse = torch.cat([head_lse, tail_lse], dim=1)
|
||||||
|
attn_lse = attn_lse[:, q_full_idx]
|
||||||
|
|
||||||
|
# Post-processing: keep [tokens, H, V] shape and perform chunked context accumulation if needed
|
||||||
|
if attn_metadata.prefill is not None and \
|
||||||
|
attn_metadata.prefill.chunked_context is not None:
|
||||||
|
# q all_gather
|
||||||
|
q_nope_full = get_pcp_group().all_gather(q_nope.contiguous(), 0)
|
||||||
|
q_pe_full = get_pcp_group().all_gather(q_pe.contiguous(), 0)
|
||||||
|
q_nope_full = torch.index_select(
|
||||||
|
q_nope_full, 0,
|
||||||
|
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||||
|
q_pe_full = torch.index_select(
|
||||||
|
q_pe_full, 0,
|
||||||
|
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||||
|
attn_output_pre = output.view(num_tokens, self.num_heads,
|
||||||
|
self.v_head_dim)
|
||||||
|
attn_output_pre_full, attn_lse_full = self._compute_prefill_context(
|
||||||
|
q_nope_full,
|
||||||
|
q_pe_full,
|
||||||
|
kv_c_and_k_pe_cache,
|
||||||
|
self.qk_rope_head_dim,
|
||||||
|
attn_metadata,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
# reorder back && extract output + lse result of each cp rank
|
||||||
|
inverse_idx = torch.argsort(
|
||||||
|
attn_metadata.prefill.cp_kv_recover_idx_for_chunk)
|
||||||
|
attn_output_pre_full = torch.index_select(attn_output_pre_full, 0,
|
||||||
|
inverse_idx)
|
||||||
|
attn_lse_full = torch.index_select(attn_lse_full, 1, inverse_idx)
|
||||||
|
attn_output_pre_new = attn_output_pre_full[
|
||||||
|
self.pcp_rank * num_tokens:(self.pcp_rank + 1) *
|
||||||
|
num_tokens, :, :]
|
||||||
|
attn_lse_new = attn_lse_full[:, self.pcp_rank *
|
||||||
|
num_tokens:(self.pcp_rank + 1) *
|
||||||
|
num_tokens]
|
||||||
|
|
||||||
|
# update(output_origin, output_new)
|
||||||
|
assert attn_output_pre_new.shape == attn_output_pre.shape and attn_lse_new.shape == attn_lse.shape
|
||||||
|
seq_len = torch.tensor(attn_metadata.prefill.query_lens,
|
||||||
|
dtype=torch.int32)
|
||||||
|
mask_for_non_zero_chunk = attn_metadata.prefill.chunked_context.mask_for_non_zero_chunk
|
||||||
|
filtered_indices = filter_chunked_req_indices(
|
||||||
|
seq_len, mask_for_non_zero_chunk)
|
||||||
|
attn_output_pre_filtered = attn_output_pre[filtered_indices, :, :]
|
||||||
|
attn_lse_filtered = attn_lse[:, filtered_indices]
|
||||||
|
attn_output_pre_new = attn_output_pre_new[filtered_indices, :, :]
|
||||||
|
attn_lse_new = attn_lse_new[:, filtered_indices]
|
||||||
|
|
||||||
|
# normalize prefix LSE to [bs, heads, 1] for stable updates
|
||||||
|
attn_lse_filtered = attn_lse_filtered.permute(1, 0).unsqueeze(-1)
|
||||||
|
attn_lse_new = attn_lse_new.permute(1, 0).unsqueeze(-1)
|
||||||
|
token_mask = torch.ones([attn_lse_new.size(0)],
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=attn_lse_new.device)
|
||||||
|
attn_output_pre_filtered, attn_lse_filtered = self._update_out_and_lse(
|
||||||
|
attn_output_pre_filtered, attn_lse_filtered,
|
||||||
|
attn_output_pre_new, attn_lse_new, token_mask)
|
||||||
|
# convert lse back to [heads, bs]
|
||||||
|
attn_lse_filtered = attn_lse_filtered.squeeze(-1).permute(
|
||||||
|
1, 0).contiguous()
|
||||||
|
|
||||||
|
attn_output_pre[
|
||||||
|
filtered_indices, :, :] = attn_output_pre_filtered.to(
|
||||||
|
attn_output_pre.dtype)
|
||||||
|
attn_lse[:,
|
||||||
|
filtered_indices] = attn_lse_filtered.to(attn_lse.dtype)
|
||||||
|
|
||||||
|
attn_output_pre = attn_output_pre.to(q_nope.dtype)
|
||||||
|
output = attn_output_pre.reshape(
|
||||||
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
|
else:
|
||||||
|
output = output.reshape(
|
||||||
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -1588,7 +1945,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
# nomask
|
# nomask
|
||||||
if kv_nomask_idx.shape[0] == 0:
|
if kv_nomask_idx.shape[0] == 0:
|
||||||
return attn_output
|
return attn_output, attn_lse
|
||||||
|
|
||||||
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
|
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
|
||||||
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
|
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
|
||||||
@@ -1611,7 +1968,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
calc_type="calc_type_default",
|
calc_type="calc_type_default",
|
||||||
output=attn_output,
|
output=attn_output,
|
||||||
softmax_lse=attn_lse)
|
softmax_lse=attn_lse)
|
||||||
return attn_output
|
return attn_output, attn_lse
|
||||||
|
|
||||||
def _forward_decode_pcp_dcp(
|
def _forward_decode_pcp_dcp(
|
||||||
self,
|
self,
|
||||||
@@ -1788,3 +2145,33 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||||
|
|
||||||
return attn_out_lse_list
|
return attn_out_lse_list
|
||||||
|
|
||||||
|
# TODO use update op to replace this
|
||||||
|
def _update_out_and_lse(
|
||||||
|
self,
|
||||||
|
out: torch.Tensor,
|
||||||
|
lse: torch.Tensor,
|
||||||
|
block_out: torch.Tensor,
|
||||||
|
block_lse: torch.Tensor,
|
||||||
|
mask: torch.Tensor = None,
|
||||||
|
):
|
||||||
|
if out is None:
|
||||||
|
out = block_out.to(torch.float32)
|
||||||
|
lse = block_lse
|
||||||
|
else:
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.ones([block_out.size(0)],
|
||||||
|
dtype=torch.uint8,
|
||||||
|
device=block_out.device)
|
||||||
|
out_mask = mask[:, None, None].expand_as(block_out)
|
||||||
|
lse_mask = mask[:, None, None].expand_as(block_lse)
|
||||||
|
block_out = block_out.to(torch.float32)
|
||||||
|
out_without_update = out.clone()
|
||||||
|
lse_without_update = lse.clone()
|
||||||
|
|
||||||
|
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||||
|
lse = lse - F.logsigmoid(lse - block_lse)
|
||||||
|
# mask
|
||||||
|
out = torch.where(out_mask, out, out_without_update)
|
||||||
|
lse = torch.where(lse_mask, lse, lse_without_update)
|
||||||
|
return out, lse
|
||||||
|
|||||||
@@ -14,10 +14,19 @@ from vllm.forward_context import ForwardContext, get_forward_context
|
|||||||
class AscendPrefillContextParallelMetadata:
|
class AscendPrefillContextParallelMetadata:
|
||||||
pcp_allgather_restore_idx: torch.Tensor = None
|
pcp_allgather_restore_idx: torch.Tensor = None
|
||||||
|
|
||||||
|
cp_kv_recover_idx_for_chunk: torch.Tensor = None
|
||||||
|
|
||||||
num_actual_tokens_pcp_padded: Optional[int] = None
|
num_actual_tokens_pcp_padded: Optional[int] = None
|
||||||
|
|
||||||
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
|
||||||
|
|
||||||
|
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
|
||||||
|
list[int]]]]]]]] = None
|
||||||
|
|
||||||
|
mask_for_non_zero_chunk: Optional[List[bool]] = None
|
||||||
|
|
||||||
|
max_chunk_num: int = 0
|
||||||
|
|
||||||
q_head_idx_tensor: torch.Tensor = None
|
q_head_idx_tensor: torch.Tensor = None
|
||||||
|
|
||||||
q_tail_idx_tensor: torch.Tensor = None
|
q_tail_idx_tensor: torch.Tensor = None
|
||||||
@@ -106,6 +115,47 @@ class AscendCommonAttentionMetadata:
|
|||||||
AscendPrefillContextParallelMetadata] = None
|
AscendPrefillContextParallelMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_req_dcp_by_chunk_pcp(lst,
|
||||||
|
chunk_idx,
|
||||||
|
dcp_size,
|
||||||
|
pcp_rank,
|
||||||
|
fill_value=0):
|
||||||
|
num_reqs = len(lst)
|
||||||
|
results: List[List[int]] = []
|
||||||
|
for i in range(num_reqs):
|
||||||
|
if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
|
||||||
|
# empty req or this req has no corresponding chunk, fill 0
|
||||||
|
results.append([fill_value] * dcp_size)
|
||||||
|
continue
|
||||||
|
dcp_values = lst[i][chunk_idx][pcp_rank]
|
||||||
|
results.append(dcp_values)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def filter_chunked_req_indices(
|
||||||
|
seq_len: torch.Tensor,
|
||||||
|
mask_for_non_zero_chunk: Optional[List[bool]],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
filter the reqs which are doing real chunk_prefill.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_len: contains multi-req length: [req0_len, req1_len, ...]
|
||||||
|
mask_for_non_zero_chunk: [True, False, True, False, ...]
|
||||||
|
Returns:
|
||||||
|
filtered_indices: the real chunked req's indices
|
||||||
|
"""
|
||||||
|
assert mask_for_non_zero_chunk is not None and len(seq_len) == len(
|
||||||
|
mask_for_non_zero_chunk)
|
||||||
|
offsets = torch.cumsum(torch.cat([torch.tensor([0]), seq_len[:-1]]), dim=0)
|
||||||
|
filtered_indices = torch.cat([
|
||||||
|
torch.arange(offsets[i], offsets[i] + seq_len[i])
|
||||||
|
for i in range(len(mask_for_non_zero_chunk))
|
||||||
|
if mask_for_non_zero_chunk[i]
|
||||||
|
])
|
||||||
|
return filtered_indices
|
||||||
|
|
||||||
|
|
||||||
def split_decodes_and_prefills(
|
def split_decodes_and_prefills(
|
||||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||||
decode_threshold: int = 1,
|
decode_threshold: int = 1,
|
||||||
|
|||||||
@@ -77,14 +77,6 @@ class BlockTable:
|
|||||||
self.block_table_np = self.block_table_cpu.numpy()
|
self.block_table_np = self.block_table_cpu.numpy()
|
||||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
|
||||||
self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
try:
|
try:
|
||||||
self.pcp_world_size = get_pcp_group(
|
self.pcp_world_size = get_pcp_group(
|
||||||
).world_size if prefill_context_parallel_enable() else 1
|
).world_size if prefill_context_parallel_enable() else 1
|
||||||
@@ -98,6 +90,20 @@ class BlockTable:
|
|||||||
self.dcp_rank = 0
|
self.dcp_rank = 0
|
||||||
self.pcp_world_size = 1
|
self.pcp_world_size = 1
|
||||||
self.pcp_rank = 0
|
self.pcp_rank = 0
|
||||||
|
|
||||||
|
self.slot_mapping_cpu = torch.zeros(
|
||||||
|
self.max_num_batched_tokens +
|
||||||
|
2 * self.pcp_world_size * self.max_num_reqs,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=self.pin_memory)
|
||||||
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
self.slot_mapping = torch.zeros(
|
||||||
|
self.max_num_batched_tokens +
|
||||||
|
2 * self.pcp_world_size * self.max_num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
self.kernel_sizes = kernel_sizes
|
self.kernel_sizes = kernel_sizes
|
||||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||||
|
|
||||||
@@ -148,7 +154,7 @@ class BlockTable:
|
|||||||
if self.dcp_world_size * self.pcp_world_size > 1:
|
if self.dcp_world_size * self.pcp_world_size > 1:
|
||||||
# Note(hc): The DCP implement store kvcache with an interleave
|
# Note(hc): The DCP implement store kvcache with an interleave
|
||||||
# style, the kvcache for the token whose token_idx is i is
|
# style, the kvcache for the token whose token_idx is i is
|
||||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
# always stored on the GPU whose dcp_rank equals i % pcp_world_size:
|
||||||
|
|
||||||
# Use a "virtual block" which equals to world_size * block_size
|
# Use a "virtual block" which equals to world_size * block_size
|
||||||
# for block_table_indices calculation.
|
# for block_table_indices calculation.
|
||||||
@@ -268,12 +274,12 @@ class MultiGroupBlockTable:
|
|||||||
# must be multiplied by dcp_world_size.
|
# must be multiplied by dcp_world_size.
|
||||||
try:
|
try:
|
||||||
dcp_world_size = get_dcp_group().world_size
|
dcp_world_size = get_dcp_group().world_size
|
||||||
cp_world_size = get_pcp_group(
|
pcp_world_size = get_pcp_group(
|
||||||
).world_size if prefill_context_parallel_enable() else 1
|
).world_size if prefill_context_parallel_enable() else 1
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
# DCP might not be initialized in testing
|
# DCP might not be initialized in testing
|
||||||
dcp_world_size = 1
|
dcp_world_size = 1
|
||||||
cp_world_size = 1
|
pcp_world_size = 1
|
||||||
|
|
||||||
if kernel_sizes is None:
|
if kernel_sizes is None:
|
||||||
kernel_sizes = [[0]] * len(block_sizes)
|
kernel_sizes = [[0]] * len(block_sizes)
|
||||||
@@ -291,7 +297,7 @@ class MultiGroupBlockTable:
|
|||||||
block_size, max_num_reqs,
|
block_size, max_num_reqs,
|
||||||
max(
|
max(
|
||||||
cdiv(max_model_len,
|
cdiv(max_model_len,
|
||||||
block_size * dcp_world_size * cp_world_size),
|
block_size * dcp_world_size * pcp_world_size),
|
||||||
1 + num_speculative_tokens), max_num_batched_tokens,
|
1 + num_speculative_tokens), max_num_batched_tokens,
|
||||||
pin_memory, device, kernel_size_list,
|
pin_memory, device, kernel_size_list,
|
||||||
cp_kv_cache_interleave_size)
|
cp_kv_cache_interleave_size)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from copy import deepcopy
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Manager
|
from multiprocessing import Manager
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
|
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
|
||||||
Union, cast)
|
Tuple, Union, cast)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@@ -471,13 +471,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
|
self.pcp_allgather_restore_idx = torch.zeros(
|
||||||
dtype=torch.int32,
|
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||||
device=self.device)
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
self.cp_kv_recover_idx_for_chunk: List[List[int]] = [
|
||||||
|
[] for _ in range(self.pcp_size)
|
||||||
|
]
|
||||||
|
|
||||||
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
|
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
|
||||||
self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
|
self.pcp_padded_slot_mapping = torch.zeros(
|
||||||
dtype=torch.int32,
|
self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs,
|
||||||
device=self.device)
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
self.num_actual_tokens_pcp_padded = 0
|
self.num_actual_tokens_pcp_padded = 0
|
||||||
if self.speculative_config and self.pcp_size > 1:
|
if self.speculative_config and self.pcp_size > 1:
|
||||||
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
|
||||||
@@ -739,7 +745,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
backward_kwargs = {}
|
backward_kwargs = {}
|
||||||
backward_kwargs["mm_features"] = new_req_data.mm_features
|
backward_kwargs["mm_features"] = new_req_data.mm_features
|
||||||
|
|
||||||
self.requests[req_id] = CachedRequestState(
|
# Create request state - PCP/DCP tracking will be computed below
|
||||||
|
req_state = CachedRequestState(
|
||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
prompt_embeds=new_req_data.prompt_embeds,
|
prompt_embeds=new_req_data.prompt_embeds,
|
||||||
@@ -750,9 +757,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
output_token_ids=[],
|
output_token_ids=[],
|
||||||
lora_request=new_req_data.lora_request,
|
lora_request=new_req_data.lora_request,
|
||||||
|
local_chunked_kv_lens=None,
|
||||||
**backward_kwargs,
|
**backward_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute PCP/DCP tracking fields for chunked prefill
|
||||||
|
self.input_batch.local_chunked_kv_lens = [None] * self.max_num_reqs
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
num_computed_tokens = new_req_data.num_computed_tokens
|
||||||
|
if num_computed_tokens > 0:
|
||||||
|
# Initialize with starting rank 0
|
||||||
|
temp_start_rank_dict = {req_id: (0, 0)}
|
||||||
|
|
||||||
|
# Compute token distribution for initial tokens
|
||||||
|
current_distribution = self.get_split_computed_tokens(
|
||||||
|
np.array([num_computed_tokens]),
|
||||||
|
request_ids=[req_id],
|
||||||
|
request_start_rank_dict=temp_start_rank_dict,
|
||||||
|
cp_kv_cache_interleave_size=self.parallel_config.
|
||||||
|
cp_kv_cache_interleave_size,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# Update next_pcp_dcp_start_rank
|
||||||
|
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
|
||||||
|
req_id][0]
|
||||||
|
req_state.token_blank_in_last_blk = temp_start_rank_dict[
|
||||||
|
req_id][1]
|
||||||
|
|
||||||
|
req_state.local_chunked_kv_lens = [
|
||||||
|
copy.deepcopy(current_distribution)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# No computed tokens yet
|
||||||
|
req_state.local_chunked_kv_lens = []
|
||||||
|
|
||||||
|
self.requests[req_id] = req_state
|
||||||
|
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
self._init_mrope_positions(self.requests[req_id])
|
self._init_mrope_positions(self.requests[req_id])
|
||||||
@@ -769,8 +809,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||||
|
|
||||||
# Update the cached states.
|
# Update the cached states.
|
||||||
|
prev_num_computed_tokens = req_state.num_computed_tokens
|
||||||
req_state.num_computed_tokens = num_computed_tokens
|
req_state.num_computed_tokens = num_computed_tokens
|
||||||
|
|
||||||
|
# Compute PCP/DCP tracking fields for chunked prefill
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
# If this is the first chunk, initialize tracking fields
|
||||||
|
if req_state.local_chunked_kv_lens is None:
|
||||||
|
req_state.local_chunked_kv_lens = []
|
||||||
|
|
||||||
|
# Compute tokens added in this chunk (not cumulative)
|
||||||
|
chunk_tokens = num_computed_tokens - prev_num_computed_tokens
|
||||||
|
|
||||||
|
if chunk_tokens > 0:
|
||||||
|
# Create a temporary dict with this request's starting rank
|
||||||
|
temp_start_rank_dict = {
|
||||||
|
req_id: (req_state.next_pcp_dcp_start_rank,
|
||||||
|
req_state.token_blank_in_last_blk)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Compute distribution for this chunk only
|
||||||
|
chunk_distribution = self.get_split_computed_tokens(
|
||||||
|
np.array([chunk_tokens]),
|
||||||
|
request_ids=[req_id],
|
||||||
|
request_start_rank_dict=temp_start_rank_dict,
|
||||||
|
cp_kv_cache_interleave_size=self.parallel_config.
|
||||||
|
cp_kv_cache_interleave_size,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# Update next_pcp_dcp_start_rank for this request
|
||||||
|
req_state.next_pcp_dcp_start_rank = temp_start_rank_dict[
|
||||||
|
req_id][0]
|
||||||
|
req_state.token_blank_in_last_blk = temp_start_rank_dict[
|
||||||
|
req_id][1]
|
||||||
|
|
||||||
|
# Append this chunk's distribution to accumulation list
|
||||||
|
req_state.local_chunked_kv_lens.append(
|
||||||
|
copy.deepcopy(chunk_distribution))
|
||||||
|
|
||||||
if not is_last_rank:
|
if not is_last_rank:
|
||||||
# When using PP, the scheduler sends the sampled tokens back,
|
# When using PP, the scheduler sends the sampled tokens back,
|
||||||
# because there's no direct communication between the first-
|
# because there's no direct communication between the first-
|
||||||
@@ -815,6 +891,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.block_table.append_row(
|
self.input_batch.block_table.append_row(
|
||||||
new_block_ids, req_index)
|
new_block_ids, req_index)
|
||||||
|
|
||||||
|
# Update PCP/DCP tracking fields in input_batch
|
||||||
|
self.input_batch.local_chunked_kv_lens[
|
||||||
|
req_index] = req_state.local_chunked_kv_lens
|
||||||
|
|
||||||
# For the last rank, we don't need to update the token_ids_cpu
|
# For the last rank, we don't need to update the token_ids_cpu
|
||||||
# because the sampled tokens are already cached.
|
# because the sampled tokens are already cached.
|
||||||
if not is_last_rank:
|
if not is_last_rank:
|
||||||
@@ -979,6 +1059,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return None
|
return None
|
||||||
if self.attn_mask_builder is None:
|
if self.attn_mask_builder is None:
|
||||||
raise ValueError("Attn mask builder is None")
|
raise ValueError("Attn mask builder is None")
|
||||||
|
if self.dcp_size > 1:
|
||||||
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
# Pooling situation.
|
# Pooling situation.
|
||||||
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
|
||||||
return self.attn_mask_builder.get_pooling_mask(self.device)
|
return self.attn_mask_builder.get_pooling_mask(self.device)
|
||||||
@@ -1378,6 +1460,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output,
|
scheduler_output,
|
||||||
decode_threshold=self.reorder_batch_threshold)
|
decode_threshold=self.reorder_batch_threshold)
|
||||||
|
|
||||||
|
def generate_kv_idx(self, tokens, scheduler_output):
|
||||||
|
if not self.pcp_size > 1:
|
||||||
|
return
|
||||||
|
self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)]
|
||||||
|
|
||||||
|
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
is_prefill = self.input_batch.num_computed_tokens_cpu[
|
||||||
|
i] < self.input_batch.num_prompt_tokens[i]
|
||||||
|
if is_prefill:
|
||||||
|
num_cp_padded_scheduled_tokens = cdiv(
|
||||||
|
num_scheduled_tokens,
|
||||||
|
2 * self.pcp_size) * (2 * self.pcp_size)
|
||||||
|
full_indices = list(
|
||||||
|
range(self.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||||
|
self.pcp_size * self.dcp_size * self.max_num_reqs))
|
||||||
|
chunk_size = num_cp_padded_scheduled_tokens // (2 *
|
||||||
|
self.pcp_size)
|
||||||
|
num_added_recover_tokens = len(
|
||||||
|
self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size
|
||||||
|
for rank in range(self.pcp_size):
|
||||||
|
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||||
|
full_indices[rank * chunk_size +
|
||||||
|
num_added_recover_tokens:(rank + 1) *
|
||||||
|
chunk_size + num_added_recover_tokens])
|
||||||
|
self.cp_kv_recover_idx_for_chunk[rank].extend(
|
||||||
|
full_indices[num_cp_padded_scheduled_tokens -
|
||||||
|
(rank + 1) * chunk_size +
|
||||||
|
num_added_recover_tokens:
|
||||||
|
num_cp_padded_scheduled_tokens -
|
||||||
|
rank * chunk_size +
|
||||||
|
num_added_recover_tokens])
|
||||||
|
|
||||||
|
cp_kv_recover_idx_for_chunk = torch.from_numpy(
|
||||||
|
np.concatenate(
|
||||||
|
self.cp_kv_recover_idx_for_chunk)).to(device=self.device)
|
||||||
|
cp_kv_recover_idx_for_chunk.copy_(torch.tensor(
|
||||||
|
np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()),
|
||||||
|
non_blocking=True)
|
||||||
|
self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to(
|
||||||
|
torch.float32).argsort().to(torch.int32)
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@@ -1406,7 +1531,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.num_computed_tokens_cpu[req_indices],
|
self.input_batch.num_computed_tokens_cpu[req_indices],
|
||||||
arange,
|
arange,
|
||||||
)
|
)
|
||||||
|
self.generate_kv_idx(tokens, scheduler_output)
|
||||||
self.input_batch.block_table.compute_slot_mapping(
|
self.input_batch.block_table.compute_slot_mapping(
|
||||||
req_indices, positions_np)
|
req_indices, positions_np)
|
||||||
self.input_batch.block_table.commit_slot_mapping(
|
self.input_batch.block_table.commit_slot_mapping(
|
||||||
@@ -1610,15 +1735,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
]
|
]
|
||||||
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
num_tokens_np = np.array(num_tokens, dtype=np.int32)
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
if self.pcp_size == 1:
|
if self.pcp_size > 1:
|
||||||
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
|
|
||||||
else:
|
|
||||||
# while pcp > 1, we need the original num_scheduled_tokens before split
|
# while pcp > 1, we need the original num_scheduled_tokens before split
|
||||||
# to calculate discard_requests_mask
|
# to calculate discard_requests_mask
|
||||||
|
tokens_original = [
|
||||||
|
scheduler_output.num_scheduled_tokens[i] for i in req_ids
|
||||||
|
]
|
||||||
original_seq_lens_np = (
|
original_seq_lens_np = (
|
||||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||||
np.array(list(scheduler_output.num_scheduled_tokens.values())))
|
np.array(tokens_original, dtype=np.int32))
|
||||||
discard_requests_mask = original_seq_lens_np < num_tokens_np
|
discard_requests_mask = original_seq_lens_np < num_tokens_np
|
||||||
|
else:
|
||||||
|
discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
|
||||||
|
|
||||||
discard_request_indices = np.nonzero(discard_requests_mask)[0]
|
discard_request_indices = np.nonzero(discard_requests_mask)[0]
|
||||||
self.num_discarded_requests = len(discard_request_indices)
|
self.num_discarded_requests = len(discard_request_indices)
|
||||||
self.discard_request_indices.np[:self.num_discarded_requests] = (
|
self.discard_request_indices.np[:self.num_discarded_requests] = (
|
||||||
@@ -1762,8 +1891,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output.num_scheduled_tokens)
|
scheduler_output.num_scheduled_tokens)
|
||||||
|
|
||||||
# prepare pcp meta data
|
# prepare pcp meta data
|
||||||
|
# For chunked prefill, use num_scheduled_tokens instead of cumulative seq_lens
|
||||||
|
# to correctly calculate chunk_len in _generate_pcp_metadata
|
||||||
|
if self.vllm_config.scheduler_config.chunked_prefill_enabled and self.pcp_size > 1:
|
||||||
|
# In chunked prefill, seq_lens_for_chunk should be the current chunk size
|
||||||
|
seq_lens_for_chunk = torch.from_numpy(
|
||||||
|
num_scheduled_tokens[:num_reqs])
|
||||||
|
else:
|
||||||
|
# Normal mode: use cumulative sequence lengths
|
||||||
|
seq_lens_for_chunk = seq_lens_cpu
|
||||||
long_seq_metadata = self._generate_pcp_metadata(
|
long_seq_metadata = self._generate_pcp_metadata(
|
||||||
total_num_scheduled_tokens, seq_lens_cpu)
|
total_num_scheduled_tokens, seq_lens_for_chunk, seq_lens_cpu)
|
||||||
# Prepare the attention metadata for each KV cache group and make layers
|
# Prepare the attention metadata for each KV cache group and make layers
|
||||||
# in the same group share the same metadata.
|
# in the same group share the same metadata.
|
||||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||||
@@ -2690,7 +2828,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
long_seq_metadata = self._generate_pcp_metadata(
|
long_seq_metadata = self._generate_pcp_metadata(
|
||||||
num_tokens, self.seq_lens_cpu)
|
num_tokens, self.seq_lens_cpu, self.seq_lens_cpu)
|
||||||
if long_seq_metadata is not None:
|
if long_seq_metadata is not None:
|
||||||
pcp_world_size = get_pcp_group(
|
pcp_world_size = get_pcp_group(
|
||||||
).world_size if prefill_context_parallel_enable() else 1
|
).world_size if prefill_context_parallel_enable() else 1
|
||||||
@@ -4266,23 +4404,149 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
[-1, pcp_world_size, dcp_world_size])
|
[-1, pcp_world_size, dcp_world_size])
|
||||||
return dcp_local_seq_lens
|
return dcp_local_seq_lens
|
||||||
|
|
||||||
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
|
def get_split_computed_tokens(
|
||||||
|
self,
|
||||||
|
num_computed_tokens: np.ndarray,
|
||||||
|
request_ids: Optional[List[str]] = None,
|
||||||
|
request_start_rank_dict: Dict[str, tuple[
|
||||||
|
int, int]] = {}, # tuple: start_rank, tokens_blank_in_this_block
|
||||||
|
cp_kv_cache_interleave_size: int = 1
|
||||||
|
) -> list[Optional[list[Optional[list[int]]]]]:
|
||||||
|
"""Splits computed token counts across dcp and sp dimensions for distributed allocation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_computed_tokens: Number of tokens for each request (current chunk, not cumulative)
|
||||||
|
request_ids: Request IDs to track state
|
||||||
|
request_start_rank_dict: Dict mapping req_id to the starting rank for this chunk.
|
||||||
|
Will be updated with next starting rank after distribution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of [pcp_size][dcp_size] distribution for each request
|
||||||
|
"""
|
||||||
|
self.pcp_world_size = get_pcp_group(
|
||||||
|
).world_size if prefill_context_parallel_enable() else 1
|
||||||
|
self.dcp_world_size = get_dcp_group().world_size
|
||||||
|
num_requests = len(num_computed_tokens)
|
||||||
|
assert request_start_rank_dict is not None and request_ids is not None and len(
|
||||||
|
request_ids) == num_requests
|
||||||
|
local_chunked_kv_lens = [[[0] * self.dcp_world_size
|
||||||
|
for _ in range(self.pcp_world_size)]
|
||||||
|
for _ in range(num_requests)]
|
||||||
|
total_ranks = self.pcp_world_size * self.dcp_world_size
|
||||||
|
|
||||||
|
for req_idx, (req_id, total_tokens) in enumerate(
|
||||||
|
zip(request_ids, num_computed_tokens)):
|
||||||
|
if total_tokens <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get starting rank for this chunk
|
||||||
|
start_rank = 0
|
||||||
|
tokens_blank = 0
|
||||||
|
if request_start_rank_dict is not None:
|
||||||
|
start_rank, tokens_blank = request_start_rank_dict.get(
|
||||||
|
req_id, (0, 0))
|
||||||
|
|
||||||
|
if tokens_blank > 0: # need to continue writing in the last block of previous chunk
|
||||||
|
consumed_tokens = min(tokens_blank, total_tokens)
|
||||||
|
total_tokens -= consumed_tokens
|
||||||
|
tokens_blank -= consumed_tokens
|
||||||
|
pcp_idx = start_rank // self.dcp_world_size
|
||||||
|
dcp_idx = start_rank % self.dcp_world_size
|
||||||
|
local_chunked_kv_lens[req_idx][pcp_idx][
|
||||||
|
dcp_idx] += consumed_tokens
|
||||||
|
if tokens_blank == 0:
|
||||||
|
start_rank = (start_rank + 1) % total_ranks
|
||||||
|
if total_tokens == 0:
|
||||||
|
request_start_rank_dict[req_id] = (start_rank,
|
||||||
|
tokens_blank)
|
||||||
|
continue
|
||||||
|
|
||||||
|
virtual_size = total_ranks * cp_kv_cache_interleave_size
|
||||||
|
base = int(total_tokens) // virtual_size
|
||||||
|
|
||||||
|
# Distribute base tokens to all ranks
|
||||||
|
for rank_idx in range(total_ranks):
|
||||||
|
pcp_idx = rank_idx // self.dcp_world_size
|
||||||
|
dcp_idx = rank_idx % self.dcp_world_size
|
||||||
|
local_chunked_kv_lens[req_idx][pcp_idx][
|
||||||
|
dcp_idx] += base * cp_kv_cache_interleave_size
|
||||||
|
|
||||||
|
remainder = int(total_tokens) % virtual_size
|
||||||
|
if remainder == 0:
|
||||||
|
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
|
||||||
|
continue
|
||||||
|
remain_blocks = cdiv(remainder, cp_kv_cache_interleave_size)
|
||||||
|
assert remain_blocks > 0
|
||||||
|
|
||||||
|
# Distribute remainder tokens starting from start_rank
|
||||||
|
for i in range(remain_blocks):
|
||||||
|
rank = (start_rank + i) % total_ranks
|
||||||
|
pcp_idx = rank // self.dcp_world_size
|
||||||
|
dcp_idx = rank % self.dcp_world_size
|
||||||
|
if i < remain_blocks - 1 or remainder % cp_kv_cache_interleave_size == 0: # not last block or divisible
|
||||||
|
local_chunked_kv_lens[req_idx][pcp_idx][
|
||||||
|
dcp_idx] += 1 * cp_kv_cache_interleave_size
|
||||||
|
tokens_blank = 0
|
||||||
|
else: # if last block and undivisible
|
||||||
|
local_chunked_kv_lens[req_idx][pcp_idx][
|
||||||
|
dcp_idx] += remainder % cp_kv_cache_interleave_size
|
||||||
|
tokens_blank = cp_kv_cache_interleave_size - (
|
||||||
|
remainder % cp_kv_cache_interleave_size)
|
||||||
|
start_rank = (start_rank + remain_blocks - 1) % total_ranks
|
||||||
|
if tokens_blank == 0:
|
||||||
|
start_rank = (start_rank + 1) % total_ranks
|
||||||
|
|
||||||
|
# Update next starting rank for this request
|
||||||
|
request_start_rank_dict[req_id] = (start_rank, tokens_blank)
|
||||||
|
|
||||||
|
return cast(List[Optional[List[Optional[List[int]]]]],
|
||||||
|
local_chunked_kv_lens)
|
||||||
|
|
||||||
|
def _get_chunked_req_mask_and_max_chunk(
|
||||||
|
self,
|
||||||
|
local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[
|
||||||
|
Optional[list[int]]]]]]]] = None
|
||||||
|
) -> Tuple[List[bool], int]:
|
||||||
|
"""
|
||||||
|
given 4-d list [req][chunk][pcp][dcp], return:
|
||||||
|
1. if each req has any chunk (list[bool])
|
||||||
|
2. max chunk num along all reqs (int)
|
||||||
|
"""
|
||||||
|
assert local_chunked_kv_lens is not None
|
||||||
|
if len(local_chunked_kv_lens) == 0:
|
||||||
|
return ([], 0)
|
||||||
|
mask_for_non_zero_chunk = [
|
||||||
|
len(req) > 0 for req in local_chunked_kv_lens if req is not None
|
||||||
|
]
|
||||||
|
max_chunk_num = max(
|
||||||
|
(len(req) for req in local_chunked_kv_lens if req is not None),
|
||||||
|
default=0)
|
||||||
|
return mask_for_non_zero_chunk, max_chunk_num
|
||||||
|
|
||||||
|
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens,
|
||||||
|
seq_lens_origin):
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
|
||||||
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
>= self.input_batch.num_prompt_tokens[:num_reqs])
|
||||||
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
|
||||||
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
|
||||||
|
local_chunked_kv_lens = self.input_batch.local_chunked_kv_lens[
|
||||||
|
num_decodes:num_reqs]
|
||||||
|
mask_for_non_zero_chunk, max_chunk_num = self._get_chunked_req_mask_and_max_chunk(
|
||||||
|
local_chunked_kv_lens)
|
||||||
long_seq_metadata = None
|
long_seq_metadata = None
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
long_seq_metadata = AscendPrefillContextParallelMetadata(
|
||||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||||
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
|
num_computed_tokens_of_pcp_dcp=self._get_pcp_local_seq_lens(
|
||||||
seq_lens,
|
seq_lens_origin,
|
||||||
self.pcp_size,
|
self.pcp_size,
|
||||||
self.dcp_size,
|
self.dcp_size,
|
||||||
self.parallel_config.cp_kv_cache_interleave_size,
|
self.parallel_config.cp_kv_cache_interleave_size,
|
||||||
).numpy(),
|
).numpy(),
|
||||||
)
|
local_chunked_kv_lens=local_chunked_kv_lens,
|
||||||
|
mask_for_non_zero_chunk=mask_for_non_zero_chunk,
|
||||||
|
max_chunk_num=max_chunk_num)
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
q_head_idx, q_tail_idx = [], []
|
q_head_idx, q_tail_idx = [], []
|
||||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||||
@@ -4393,6 +4657,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
}
|
}
|
||||||
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
|
long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
|
||||||
num_actual_tokens_pcp_padded]
|
num_actual_tokens_pcp_padded]
|
||||||
|
long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk
|
||||||
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
|
||||||
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
|
||||||
long_seq_metadata.q_full_idx = self.q_full_idx
|
long_seq_metadata.q_full_idx = self.q_full_idx
|
||||||
|
|||||||
@@ -73,6 +73,12 @@ class CachedRequestState:
|
|||||||
lora_request: Optional[LoRARequest] = None
|
lora_request: Optional[LoRARequest] = None
|
||||||
prompt_embeds: Optional[torch.Tensor] = None
|
prompt_embeds: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# pcp/dcp param
|
||||||
|
local_chunked_kv_lens: Optional[list[Optional[list[Optional[
|
||||||
|
list[int]]]]]] = None # Records computed tokens for each chunk
|
||||||
|
next_pcp_dcp_start_rank: int = 0 # Tracks next starting rank for round-robin distribution
|
||||||
|
token_blank_in_last_blk: int = 0 # if the last block is not full, how many future tokens can be stored
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
self.prompt_token_ids, self.prompt_embeds)
|
self.prompt_token_ids, self.prompt_embeds)
|
||||||
@@ -313,6 +319,10 @@ class InputBatch:
|
|||||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
|
# pcp/dcp parameters
|
||||||
|
self.local_chunked_kv_lens: list[Optional[list[Optional[list[Optional[
|
||||||
|
list[int]]]]]]] = [None] * max_num_reqs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def req_ids(self) -> list[str]:
|
def req_ids(self) -> list[str]:
|
||||||
# None elements should only be present transiently
|
# None elements should only be present transiently
|
||||||
@@ -385,6 +395,9 @@ class InputBatch:
|
|||||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||||
self.block_table.add_row(request.block_ids, req_index)
|
self.block_table.add_row(request.block_ids, req_index)
|
||||||
|
|
||||||
|
# Add PCP/DCP tracking fields
|
||||||
|
self.local_chunked_kv_lens[req_index] = request.local_chunked_kv_lens
|
||||||
|
|
||||||
if sampling_params := request.sampling_params:
|
if sampling_params := request.sampling_params:
|
||||||
if (self.is_spec_decode
|
if (self.is_spec_decode
|
||||||
and is_spec_decode_unsupported(sampling_params)):
|
and is_spec_decode_unsupported(sampling_params)):
|
||||||
@@ -680,6 +693,8 @@ class InputBatch:
|
|||||||
last_req_index]
|
last_req_index]
|
||||||
self.num_computed_tokens_cpu[
|
self.num_computed_tokens_cpu[
|
||||||
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
empty_index] = self.num_computed_tokens_cpu[last_req_index]
|
||||||
|
self.local_chunked_kv_lens[
|
||||||
|
empty_index] = self.local_chunked_kv_lens[last_req_index]
|
||||||
self.block_table.move_row(last_req_index, empty_index)
|
self.block_table.move_row(last_req_index, empty_index)
|
||||||
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
self.temperature_cpu[empty_index] = self.temperature_cpu[
|
||||||
last_req_index]
|
last_req_index]
|
||||||
|
|||||||
Reference in New Issue
Block a user