MLA prefill preformance optimization (#5456)
### What this PR does / why we need it?
Since the _npu_ring_mla operator deteriorates in long-sequencescenarios,
the long sequence is split into shorter sequences for input to improve
performance.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: pichangping <1337510399@qq.com>
This commit is contained in:
@@ -813,7 +813,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
|
q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \
|
||||||
kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
|
kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info(
|
||||||
rank, pcp_size, nums_tokens_per_rank)
|
rank, pcp_size, nums_tokens_per_rank)
|
||||||
|
kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx]
|
||||||
output_head, lse_head = self.impl._attention_with_mask_and_nomask(
|
output_head, lse_head = self.impl._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
q_nope=torch.index_select(q_nope, 0, q_head_idx),
|
||||||
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
q_pe=torch.index_select(q_pe, 0, q_head_idx),
|
||||||
@@ -824,15 +824,16 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
kv_nomask_idx=kv_with_q_head_nomask_idx,
|
kv_nomask_idx=kv_with_q_head_nomask_idx,
|
||||||
attn_mask_seqlens=torch.tensor(
|
attn_mask_seqlens=torch.tensor(
|
||||||
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
|
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
|
||||||
attn_nomask_seqlens=kv_with_q_head_nomask_seqlens,
|
attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens],
|
||||||
mask=mask)
|
mask=mask)
|
||||||
self.assertEqual(output_head.shape,
|
self.assertEqual(output_head.shape,
|
||||||
(q_head_idx.shape[0], num_heads, v_head_dim))
|
(q_head_idx.shape[0], num_heads, v_head_dim))
|
||||||
self.assertEqual(lse_head.shape,
|
self.assertEqual(lse_head.shape,
|
||||||
(num_heads, q_head_idx.shape[0]))
|
(num_heads, q_head_idx.shape[0]))
|
||||||
self.assertEqual(mock_npu_ring_mla.call_count,
|
self.assertEqual(mock_npu_ring_mla.call_count,
|
||||||
1 + (kv_with_q_head_nomask_idx.shape[0] != 0))
|
1 + (len(kv_with_q_head_nomask_idx[0]) != 0))
|
||||||
mock_npu_ring_mla.reset_mock()
|
mock_npu_ring_mla.reset_mock()
|
||||||
|
kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx]
|
||||||
output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
|
output_tail, lse_tail = self.impl._attention_with_mask_and_nomask(
|
||||||
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
q_nope=torch.index_select(q_nope, 0, q_tail_idx),
|
||||||
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
q_pe=torch.index_select(q_pe, 0, q_tail_idx),
|
||||||
@@ -843,7 +844,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
kv_nomask_idx=kv_with_q_tail_nomask_idx,
|
kv_nomask_idx=kv_with_q_tail_nomask_idx,
|
||||||
attn_mask_seqlens=torch.tensor(
|
attn_mask_seqlens=torch.tensor(
|
||||||
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
|
[chunk_seqlens, chunk_seqlens], dtype=torch.int32),
|
||||||
attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens,
|
attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens],
|
||||||
mask=mask)
|
mask=mask)
|
||||||
|
|
||||||
self.assertEqual(output_tail.shape,
|
self.assertEqual(output_tail.shape,
|
||||||
@@ -851,7 +852,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.assertEqual(lse_tail.shape,
|
self.assertEqual(lse_tail.shape,
|
||||||
(num_heads, q_tail_idx.shape[0]))
|
(num_heads, q_tail_idx.shape[0]))
|
||||||
self.assertEqual(mock_npu_ring_mla.call_count,
|
self.assertEqual(mock_npu_ring_mla.call_count,
|
||||||
1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
|
1 + (len(kv_with_q_tail_nomask_idx[0]) != 0))
|
||||||
mock_npu_ring_mla.reset_mock()
|
mock_npu_ring_mla.reset_mock()
|
||||||
|
|
||||||
@patch_distributed_groups(dcp_size=2, pcp_size=2)
|
@patch_distributed_groups(dcp_size=2, pcp_size=2)
|
||||||
|
|||||||
@@ -320,3 +320,201 @@ def test_generate_pcp_mtp_input(
|
|||||||
target_input_ids_pcp_full)
|
target_input_ids_pcp_full)
|
||||||
assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
|
assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1],
|
||||||
target_query_start_loc_pcp_full)
|
target_query_start_loc_pcp_full)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"pcp_world_rank, split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs,"
|
||||||
|
"head_attn_nomask_seqlens, chunk_seqlens,"
|
||||||
|
"target_split_q_head, target_split_q_tail, target_head_seqlens, target_tail_seqlens",
|
||||||
|
[
|
||||||
|
# case1: pcp_world_rank=0
|
||||||
|
(0, [[10, 20, 30]], [[40, 50, 60]],
|
||||||
|
torch.tensor([[64], [0]], dtype=torch.int32), [64], [
|
||||||
|
torch.tensor([1, 2, 3], dtype=torch.int32)
|
||||||
|
], [torch.tensor([40, 50, 60], dtype=torch.int32)], [
|
||||||
|
torch.tensor([[64], [0]], dtype=torch.int32)
|
||||||
|
], [torch.tensor([[64], [3]], dtype=torch.int32)]),
|
||||||
|
# case2: pcp_world_rank=1
|
||||||
|
(1, [[1, 2], [3, 4, 5]], [[6, 7], [8, 9, 10]],
|
||||||
|
torch.tensor([[128, 128], [128, 128]], dtype=torch.int32), [128, 128],
|
||||||
|
[torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)], [
|
||||||
|
torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)
|
||||||
|
], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)
|
||||||
|
], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)]),
|
||||||
|
# case3: pcp_world_rank=2
|
||||||
|
(2, [[11, 12, 13, 14], [15, 16]], [[17, 18, 19], [20, 21, 22, 23]],
|
||||||
|
torch.tensor([[256, 256], [512, 512]], dtype=torch.int32), [256, 256],
|
||||||
|
[torch.tensor([11, 12, 13, 14, 15, 16], dtype=torch.int32)], [
|
||||||
|
torch.tensor([17, 18, 19, 20, 21, 22, 23], dtype=torch.int32)
|
||||||
|
], [torch.tensor([[256, 256], [4, 2]], dtype=torch.int32)
|
||||||
|
], [torch.tensor([[256, 256], [3, 4]], dtype=torch.int32)]),
|
||||||
|
# case4: empty input
|
||||||
|
(
|
||||||
|
0,
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
torch.tensor([], dtype=torch.int32).reshape(2, 0),
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
# case5: single element input
|
||||||
|
(
|
||||||
|
0,
|
||||||
|
[[10]],
|
||||||
|
[[40]],
|
||||||
|
torch.tensor([[64], [0]], dtype=torch.int32),
|
||||||
|
[64],
|
||||||
|
[torch.tensor([1, 2, 3], dtype=torch.int32)],
|
||||||
|
[torch.tensor([40], dtype=torch.int32)],
|
||||||
|
[torch.tensor([[64], [0]], dtype=torch.int32)],
|
||||||
|
[torch.tensor([[64], [1]], dtype=torch.int32)],
|
||||||
|
),
|
||||||
|
# case6: pcp_world_rank=3
|
||||||
|
(
|
||||||
|
3,
|
||||||
|
[[1, 2], [3, 4, 5]],
|
||||||
|
[[6, 7], [8, 9, 10]],
|
||||||
|
torch.tensor([[128, 128], [128, 128]], dtype=torch.int32),
|
||||||
|
[128, 128],
|
||||||
|
[torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)],
|
||||||
|
[torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)],
|
||||||
|
[torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
|
||||||
|
[torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
def test_split_nomask_idx_tensor_list(
|
||||||
|
pcp_world_rank, split_with_q_head_nomask_idx_reqs,
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens,
|
||||||
|
chunk_seqlens, target_split_q_head, target_split_q_tail,
|
||||||
|
target_head_seqlens, target_tail_seqlens):
|
||||||
|
# Mock input data
|
||||||
|
mock_runner = MagicMock(spec=PCPManager)
|
||||||
|
mock_runner.device = "cpu"
|
||||||
|
mock_runner.pcp_world_rank = 0
|
||||||
|
mock_runner.kv_idx_names = {
|
||||||
|
"kv_with_q_head_nomask_idx_tensor":
|
||||||
|
torch.tensor([1, 2, 3], dtype=torch.int32)
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_runner.pcp_world_rank = pcp_world_rank
|
||||||
|
|
||||||
|
# Mock output
|
||||||
|
mock_runner._split_multi_batch_kv_idx.side_effect = PCPManager._split_multi_batch_kv_idx.__get__(
|
||||||
|
mock_runner, PCPManager)
|
||||||
|
mock_runner._list_to_tensor.side_effect = PCPManager._list_to_tensor.__get__(
|
||||||
|
mock_runner, PCPManager)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
result = PCPManager._split_nomask_idx_tensor_list(
|
||||||
|
mock_runner,
|
||||||
|
split_with_q_head_nomask_idx_reqs=split_with_q_head_nomask_idx_reqs,
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs=
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs,
|
||||||
|
head_attn_nomask_seqlens=head_attn_nomask_seqlens,
|
||||||
|
chunk_seqlens=chunk_seqlens)
|
||||||
|
split_q_head, split_q_tail, head_seqlens, tail_seqlens = result
|
||||||
|
|
||||||
|
# Assert the method call
|
||||||
|
assert len(split_q_head) == len(target_split_q_head)
|
||||||
|
for res, target in zip(split_q_head, target_split_q_head):
|
||||||
|
assert torch.equal(res, target)
|
||||||
|
|
||||||
|
assert len(split_q_tail) == len(target_split_q_tail)
|
||||||
|
for res, target in zip(split_q_tail, target_split_q_tail):
|
||||||
|
assert torch.equal(res, target)
|
||||||
|
|
||||||
|
assert len(head_seqlens) == len(target_head_seqlens)
|
||||||
|
for res, target in zip(head_seqlens, target_head_seqlens):
|
||||||
|
if isinstance(target, torch.Tensor):
|
||||||
|
assert torch.equal(res, target)
|
||||||
|
else:
|
||||||
|
assert res == target
|
||||||
|
|
||||||
|
assert len(tail_seqlens) == len(target_tail_seqlens)
|
||||||
|
for res, target in zip(tail_seqlens, target_tail_seqlens):
|
||||||
|
if isinstance(target, torch.Tensor):
|
||||||
|
assert torch.equal(res, target)
|
||||||
|
else:
|
||||||
|
assert res == target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"kv_nomask_idx_multi_batch, split_size, expected_merged_idx, expected_merged_len",
|
||||||
|
[
|
||||||
|
# case1: multiple batches + split size greater than batch length
|
||||||
|
(
|
||||||
|
[[0, 1, 2, 3, 4], [5, 6, 7]],
|
||||||
|
2,
|
||||||
|
# expected merged_split_kv_idx_3d
|
||||||
|
[[0, 1, 5, 6], [2, 3, 7], [4]],
|
||||||
|
# expected merged_split_kv_len_2d
|
||||||
|
[[2, 2], [2, 1], [1, 0]],
|
||||||
|
),
|
||||||
|
# case2: single batch + split size greater than batch length
|
||||||
|
(
|
||||||
|
[[0, 1, 2]],
|
||||||
|
5,
|
||||||
|
[[0, 1, 2]],
|
||||||
|
[[3]],
|
||||||
|
),
|
||||||
|
# case3: split size equals maximum batch length
|
||||||
|
(
|
||||||
|
[[0, 1, 2, 3], [5, 6]],
|
||||||
|
4,
|
||||||
|
[[0, 1, 2, 3, 5, 6]],
|
||||||
|
[[4, 2]],
|
||||||
|
),
|
||||||
|
# case4: Split size is 1 (minimum granularity split)
|
||||||
|
(
|
||||||
|
[[0, 1], [2]],
|
||||||
|
1,
|
||||||
|
[[0, 2], [1]],
|
||||||
|
[[1, 1], [1, 0]],
|
||||||
|
),
|
||||||
|
# case6: the batch contains an empty list
|
||||||
|
(
|
||||||
|
[[], [0, 1], [2]],
|
||||||
|
1,
|
||||||
|
[[0, 2], [1]],
|
||||||
|
[[0, 1, 1], [0, 1, 0]],
|
||||||
|
),
|
||||||
|
# case7: empty input
|
||||||
|
(
|
||||||
|
[],
|
||||||
|
2,
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
def test_split_multi_batch_kv_idx(
|
||||||
|
kv_nomask_idx_multi_batch,
|
||||||
|
split_size,
|
||||||
|
expected_merged_idx,
|
||||||
|
expected_merged_len,
|
||||||
|
):
|
||||||
|
# Mock input data
|
||||||
|
model_runner = MagicMock(spec=PCPManager)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
result = PCPManager._split_multi_batch_kv_idx(
|
||||||
|
self=model_runner,
|
||||||
|
kv_nomask_idx_multi_batch=kv_nomask_idx_multi_batch,
|
||||||
|
split_size=split_size)
|
||||||
|
|
||||||
|
merged_split_kv_idx_3d, merged_split_kv_len_2d = result
|
||||||
|
|
||||||
|
# Assert the method call
|
||||||
|
assert len(merged_split_kv_idx_3d) == len(expected_merged_idx)
|
||||||
|
|
||||||
|
for t, (actual_seg, expected_seg) in enumerate(
|
||||||
|
zip(merged_split_kv_idx_3d, expected_merged_idx)):
|
||||||
|
assert actual_seg == expected_seg
|
||||||
|
|
||||||
|
assert len(merged_split_kv_len_2d) == len(expected_merged_len)
|
||||||
|
|
||||||
|
for t, (actual_len, expected_len) in enumerate(
|
||||||
|
zip(merged_split_kv_len_2d, expected_merged_len)):
|
||||||
|
assert actual_len == expected_len
|
||||||
|
|||||||
@@ -465,11 +465,18 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
def _attention_with_mask_and_nomask(
|
def _attention_with_mask_and_nomask(
|
||||||
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
|
self,
|
||||||
k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
|
k_nope: torch.Tensor,
|
||||||
mask: torch.Tensor):
|
k_pe: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_mask_idx: torch.Tensor,
|
||||||
|
kv_nomask_idx: list[torch.Tensor],
|
||||||
|
attn_mask_seqlens: torch.Tensor,
|
||||||
|
attn_nomask_seqlens: list[torch.Tensor],
|
||||||
|
mask: torch.Tensor,
|
||||||
|
):
|
||||||
attn_output = torch.empty(q_nope.shape[0],
|
attn_output = torch.empty(q_nope.shape[0],
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.v_head_dim,
|
self.v_head_dim,
|
||||||
@@ -503,19 +510,22 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
softmax_lse=attn_lse)
|
softmax_lse=attn_lse)
|
||||||
|
|
||||||
# nomask
|
# nomask
|
||||||
if kv_nomask_idx.shape[0] == 0:
|
if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0:
|
||||||
return attn_output, attn_lse
|
return attn_output, attn_lse
|
||||||
|
|
||||||
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
|
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(
|
||||||
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
|
kv_nomask_idx, attn_nomask_seqlens):
|
||||||
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
|
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split)
|
||||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
value_nomask = torch.index_select(value, 0, kv_nomask_idx_split)
|
||||||
|
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split)
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
q_rope=q_pe,
|
q_rope=q_pe,
|
||||||
k_nope=k_nope_nomask,
|
k_nope=k_nope_nomask,
|
||||||
k_rope=k_pe_nomask,
|
k_rope=k_pe_nomask,
|
||||||
value=value_nomask,
|
value=value_nomask,
|
||||||
mask=mask,
|
mask=mask,
|
||||||
seqlen=attn_nomask_seqlens,
|
seqlen=attn_nomask_seqlens_split,
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
kv_head_num=self.num_heads,
|
kv_head_num=self.num_heads,
|
||||||
pre_out=attn_output,
|
pre_out=attn_output,
|
||||||
|
|||||||
@@ -565,6 +565,8 @@ class PCPManager:
|
|||||||
q_head_idx, q_tail_idx = [], []
|
q_head_idx, q_tail_idx = [], []
|
||||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||||
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
|
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
|
||||||
|
split_with_q_head_nomask_idx_reqs = []
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs = []
|
||||||
chunk_seqlens = []
|
chunk_seqlens = []
|
||||||
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
|
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
|
||||||
q_req_offset = 0
|
q_req_offset = 0
|
||||||
@@ -590,7 +592,10 @@ class PCPManager:
|
|||||||
(q_head_chunk_id + 1))))
|
(q_head_chunk_id + 1))))
|
||||||
kv_with_q_head_nomask_seqlens.append(chunk_len *
|
kv_with_q_head_nomask_seqlens.append(chunk_len *
|
||||||
q_head_chunk_id)
|
q_head_chunk_id)
|
||||||
|
split_with_q_head_nomask_idx_reqs.append(
|
||||||
|
list(
|
||||||
|
range(kv_req_offset, kv_req_offset +
|
||||||
|
chunk_len * q_head_chunk_id)))
|
||||||
q_tail_idx.extend(
|
q_tail_idx.extend(
|
||||||
list(
|
list(
|
||||||
range(q_req_offset + chunk_len,
|
range(q_req_offset + chunk_len,
|
||||||
@@ -607,21 +612,17 @@ class PCPManager:
|
|||||||
(q_tail_chunk_id + 1))))
|
(q_tail_chunk_id + 1))))
|
||||||
kv_with_q_tail_nomask_seqlens.append(chunk_len *
|
kv_with_q_tail_nomask_seqlens.append(chunk_len *
|
||||||
q_tail_chunk_id)
|
q_tail_chunk_id)
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs.append(
|
||||||
|
list(
|
||||||
|
range(kv_req_offset, kv_req_offset +
|
||||||
|
chunk_len * q_tail_chunk_id)))
|
||||||
q_req_offset += seq_len
|
q_req_offset += seq_len
|
||||||
kv_req_offset += seq_len * self.pcp_world_size
|
kv_req_offset += seq_len * self.pcp_world_size
|
||||||
|
|
||||||
# Convert lists to tensors and move to device
|
q_head_idx_tensor = self._list_to_tensor(
|
||||||
def _list_to_tensor(lst, device, dtype=torch.int32):
|
q_head_idx, self.device)
|
||||||
tensor_npu = torch.zeros(len(lst),
|
q_tail_idx_tensor = self._list_to_tensor(
|
||||||
dtype=dtype,
|
q_tail_idx, self.device)
|
||||||
device=device)
|
|
||||||
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
|
|
||||||
non_blocking=True)
|
|
||||||
return tensor_npu
|
|
||||||
|
|
||||||
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
|
|
||||||
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
|
|
||||||
self.q_head_idx_tensor = q_head_idx_tensor
|
self.q_head_idx_tensor = q_head_idx_tensor
|
||||||
self.q_tail_idx_tensor = q_tail_idx_tensor
|
self.q_tail_idx_tensor = q_tail_idx_tensor
|
||||||
|
|
||||||
@@ -639,7 +640,7 @@ class PCPManager:
|
|||||||
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
|
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
|
||||||
}
|
}
|
||||||
for key, value in self.kv_idx_names.items():
|
for key, value in self.kv_idx_names.items():
|
||||||
tensor_npu = _list_to_tensor(value, self.device)
|
tensor_npu = self._list_to_tensor(value, self.device)
|
||||||
self.kv_idx_names[key] = tensor_npu
|
self.kv_idx_names[key] = tensor_npu
|
||||||
|
|
||||||
attn_mask_seqlens = torch.tensor(
|
attn_mask_seqlens = torch.tensor(
|
||||||
@@ -650,6 +651,11 @@ class PCPManager:
|
|||||||
tail_attn_nomask_seqlens = torch.tensor(
|
tail_attn_nomask_seqlens = torch.tensor(
|
||||||
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
|
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
if self.vllm_config.model_config.use_mla:
|
||||||
|
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list(
|
||||||
|
split_with_q_head_nomask_idx_reqs,
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs,
|
||||||
|
head_attn_nomask_seqlens, chunk_seqlens)
|
||||||
pcp_prefill_mask = attn_mask
|
pcp_prefill_mask = attn_mask
|
||||||
|
|
||||||
self.extra_long_seq_kwargs = {
|
self.extra_long_seq_kwargs = {
|
||||||
@@ -680,5 +686,95 @@ class PCPManager:
|
|||||||
'tail_attn_nomask_seqlens']
|
'tail_attn_nomask_seqlens']
|
||||||
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
|
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
|
||||||
'pcp_prefill_mask']
|
'pcp_prefill_mask']
|
||||||
|
if self.vllm_config.model_config.use_mla:
|
||||||
|
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
|
||||||
|
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
||||||
|
long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
|
||||||
|
long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
|
||||||
self.long_seq_metadata = long_seq_metadata
|
self.long_seq_metadata = long_seq_metadata
|
||||||
return long_seq_metadata
|
return long_seq_metadata
|
||||||
|
|
||||||
|
def _list_to_tensor(self, lst, device, dtype=torch.int32):
|
||||||
|
tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
|
||||||
|
tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
|
||||||
|
return tensor_npu
|
||||||
|
|
||||||
|
def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs,
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs,
|
||||||
|
head_attn_nomask_seqlens, chunk_seqlens):
|
||||||
|
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list= [], []
|
||||||
|
head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], []
|
||||||
|
if split_with_q_head_nomask_idx_reqs:
|
||||||
|
#In long-sequence scenarios, the computational cost and latency
|
||||||
|
#of the _npu_ring_mla operator are not proportional, so we split
|
||||||
|
#long sequences into shorter ones to improve performance.
|
||||||
|
split_size = 16 * 1024
|
||||||
|
if self.pcp_world_rank == 0:
|
||||||
|
split_q_head_nomask_idx_list = [
|
||||||
|
self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx(
|
||||||
|
split_with_q_head_nomask_idx_reqs, split_size)
|
||||||
|
split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx(
|
||||||
|
split_kv_with_q_tail_nomask_idx_reqs, split_size)
|
||||||
|
|
||||||
|
for q_head_nomask_idx in split_q_head_nomask_idx_list:
|
||||||
|
split_q_head_nomask_idx_tensor_list.append(
|
||||||
|
self._list_to_tensor(q_head_nomask_idx, self.device))
|
||||||
|
|
||||||
|
for q_tail_nomask_idx in split_q_tail_nomask_idx_list:
|
||||||
|
split_q_tail_nomask_idx_tensor_list.append(
|
||||||
|
self._list_to_tensor(q_tail_nomask_idx, self.device))
|
||||||
|
|
||||||
|
if self.pcp_world_rank == 0:
|
||||||
|
head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens]
|
||||||
|
else:
|
||||||
|
for q_head_nomask_lens in split_q_head_nomask_lens_list:
|
||||||
|
head_attn_nomask_seqlens_list.append(
|
||||||
|
torch.tensor([chunk_seqlens, q_head_nomask_lens],
|
||||||
|
dtype=torch.int32))
|
||||||
|
for q_tail_nomask_lens in split_q_tail_nomask_lens_list:
|
||||||
|
tail_attn_nomask_seqlens_list.append(
|
||||||
|
torch.tensor([chunk_seqlens, q_tail_nomask_lens],
|
||||||
|
dtype=torch.int32))
|
||||||
|
return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list
|
||||||
|
|
||||||
|
def _split_multi_batch_kv_idx(
|
||||||
|
self,
|
||||||
|
kv_nomask_idx_multi_batch,
|
||||||
|
split_size,
|
||||||
|
):
|
||||||
|
batch_lengths = [len(batch) for batch in kv_nomask_idx_multi_batch]
|
||||||
|
max_batch_length = max(batch_lengths) if batch_lengths else 0
|
||||||
|
time = (max_batch_length + split_size - 1) // split_size
|
||||||
|
split_kv_idx_3d = []
|
||||||
|
split_kv_len_2d = []
|
||||||
|
merged_split_kv_idx_3d = []
|
||||||
|
|
||||||
|
for single_batch in kv_nomask_idx_multi_batch:
|
||||||
|
current_batch_split = []
|
||||||
|
current_batch_len = []
|
||||||
|
for t in range(time):
|
||||||
|
start = t * split_size
|
||||||
|
current_segment = single_batch[start:start + split_size]
|
||||||
|
current_batch_split.append(current_segment)
|
||||||
|
current_batch_len.append(len(current_segment))
|
||||||
|
|
||||||
|
split_kv_idx_3d.append(current_batch_split)
|
||||||
|
split_kv_len_2d.append(current_batch_len)
|
||||||
|
|
||||||
|
for time_idx in range(time):
|
||||||
|
current_time_merged = []
|
||||||
|
for batch in split_kv_idx_3d:
|
||||||
|
current_time_merged.extend(batch[time_idx])
|
||||||
|
merged_split_kv_idx_3d.append(current_time_merged)
|
||||||
|
|
||||||
|
def reshape_kv_len_to_time_first(split_kv_len_2d):
|
||||||
|
if not split_kv_len_2d or not split_kv_len_2d[0]:
|
||||||
|
return []
|
||||||
|
return [[batch_len[time_idx] for batch_len in split_kv_len_2d]
|
||||||
|
for time_idx in range(len(split_kv_len_2d[0]))]
|
||||||
|
|
||||||
|
merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d)
|
||||||
|
return merged_split_kv_idx_3d, merged_split_kv_len_2d
|
||||||
|
|||||||
Reference in New Issue
Block a user