[Perf] vectorize PCP/DCP loops in attention_cp.py (#4944)
### What this PR does / why we need it?
- Add explicit .contiguous() after permute/view to ensure mem-friendly
layout
- Replace nested PCP/DCP Python loops with fully vectorized tensor
operations
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
This commit is contained in:
@@ -106,6 +106,8 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
self.assertEqual(output.shape[1], 4)
|
self.assertEqual(output.shape[1], 4)
|
||||||
self.assertEqual(output.shape[2], 128)
|
self.assertEqual(output.shape[2], 128)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state._PCP')
|
||||||
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
|
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
|
||||||
@patch('vllm.distributed.parallel_state._DCP')
|
@patch('vllm.distributed.parallel_state._DCP')
|
||||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||||
@@ -115,9 +117,10 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
|
def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
|
||||||
mock_all_to_all_single, mock_all_gather,
|
mock_all_to_all_single, mock_all_gather,
|
||||||
mock_npu_fused_infer_attention_score,
|
mock_npu_fused_infer_attention_score,
|
||||||
mock_dcp, mock_get_dcp_group):
|
mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||||
|
mock_pcp_group):
|
||||||
|
|
||||||
def mock_dcp_all_gather_func(tensor, dim):
|
def mock_all_gather_func(tensor, dim):
|
||||||
return torch.cat([tensor, tensor], dim=dim)
|
return torch.cat([tensor, tensor], dim=dim)
|
||||||
|
|
||||||
mock_dcp.world_size = 2
|
mock_dcp.world_size = 2
|
||||||
@@ -126,17 +129,27 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
dcp_group.rank_in_group = 0
|
dcp_group.rank_in_group = 0
|
||||||
dcp_group.world_size = 2
|
dcp_group.world_size = 2
|
||||||
dcp_group.device_group = MagicMock()
|
dcp_group.device_group = MagicMock()
|
||||||
dcp_group.all_gather = mock_dcp_all_gather_func
|
dcp_group.all_gather = mock_all_gather_func
|
||||||
mock_get_dcp_group.return_value = dcp_group
|
mock_get_dcp_group.return_value = dcp_group
|
||||||
|
|
||||||
|
mock_pcp.world_size = 2
|
||||||
|
mock_pcp.rank_in_group = 0
|
||||||
|
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||||
|
pcp_group.rank_in_group = 0
|
||||||
|
pcp_group.world_size = 2
|
||||||
|
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||||
|
pcp_group.all_gather = mock_all_gather_func
|
||||||
|
mock_pcp_group.return_value = pcp_group
|
||||||
|
|
||||||
query = torch.randn(2, 4, 128)
|
query = torch.randn(2, 4, 128)
|
||||||
self.impl.key_cache = torch.randn(100, 128, 1, 128)
|
self.impl.key_cache = torch.randn(100, 128, 1, 128)
|
||||||
self.impl.value_cache = torch.randn(100, 128, 1, 128)
|
self.impl.value_cache = torch.randn(100, 128, 1, 128)
|
||||||
|
|
||||||
def mock_npu_attention_update(attn_out_lse_list):
|
def mock_npu_attention_update(attn_out_lse_list):
|
||||||
mock_output = torch.randn(attn_out_lse_list[0].shape[0],
|
mock_output = torch.randn(
|
||||||
attn_out_lse_list[0].shape[1],
|
attn_out_lse_list.shape[0] // mock_pcp.world_size,
|
||||||
attn_out_lse_list[0].shape[2] - 1)
|
attn_out_lse_list.shape[1] // mock_dcp.world_size,
|
||||||
|
attn_out_lse_list.shape[2] - 1)
|
||||||
return mock_output
|
return mock_output
|
||||||
|
|
||||||
self.impl._npu_attention_update = MagicMock()
|
self.impl._npu_attention_update = MagicMock()
|
||||||
@@ -147,11 +160,11 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||||
input)
|
input)
|
||||||
|
|
||||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
def mock_all_gather_func1(tensor_list, tensor, group=None):
|
||||||
tensor_list[0] = tensor
|
tensor_list[0] = tensor
|
||||||
tensor_list[1] = tensor.clone()
|
tensor_list[1] = tensor.clone()
|
||||||
|
|
||||||
mock_all_gather.side_effect = mock_all_gather_func
|
mock_all_gather.side_effect = mock_all_gather_func1
|
||||||
|
|
||||||
def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
|
def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
|
||||||
**common_kwargs):
|
**common_kwargs):
|
||||||
@@ -202,8 +215,9 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
self.assertEqual(output.shape[1], 8)
|
self.assertEqual(output.shape[1], 8)
|
||||||
self.assertEqual(output.shape[2], 128)
|
self.assertEqual(output.shape[2], 128)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||||
@patch('torch.ops.npu.npu_fused_infer_attention_score')
|
@patch('torch.ops.npu.npu_fused_infer_attention_score')
|
||||||
def test_compute_prefill_context(self, mock_npu_attention):
|
def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group):
|
||||||
|
|
||||||
block_num = 100
|
block_num = 100
|
||||||
block_size = 128
|
block_size = 128
|
||||||
@@ -238,6 +252,13 @@ class TestAscendAttentionCPImpl(TestBase):
|
|||||||
self.impl._load_kv_for_chunk = MagicMock()
|
self.impl._load_kv_for_chunk = MagicMock()
|
||||||
self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk
|
self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk
|
||||||
|
|
||||||
|
def mock_all_gather_func(tensor, dim):
|
||||||
|
return torch.cat([tensor, tensor], dim=dim)
|
||||||
|
|
||||||
|
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||||
|
pcp_group.all_gather = mock_all_gather_func
|
||||||
|
mock_pcp_group.return_value = pcp_group
|
||||||
|
|
||||||
mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
|
mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
|
||||||
head_size), torch.randn(
|
head_size), torch.randn(
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -666,6 +687,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
self.assertIsInstance(out_final, torch.Tensor)
|
self.assertIsInstance(out_final, torch.Tensor)
|
||||||
self.assertIsInstance(lse_final, torch.Tensor)
|
self.assertIsInstance(lse_final, torch.Tensor)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state._PCP',
|
||||||
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch('torch.cat')
|
@patch('torch.cat')
|
||||||
@patch('torch.distributed.all_to_all_single')
|
@patch('torch.distributed.all_to_all_single')
|
||||||
@patch('torch.distributed.all_gather')
|
@patch('torch.distributed.all_gather')
|
||||||
@@ -673,7 +697,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
@patch('torch.split')
|
@patch('torch.split')
|
||||||
def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1(
|
def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1(
|
||||||
self, mock_split, mock_stack, mock_all_gather,
|
self, mock_split, mock_stack, mock_all_gather,
|
||||||
mock_all_to_all_single, mock_cat):
|
mock_all_to_all_single, mock_cat, mock_pcp, mock_get_pcp_group):
|
||||||
# Mock input data
|
# Mock input data
|
||||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||||
@@ -687,6 +711,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
mock_stack.return_value = torch.randn(6, 2, 2, 9)
|
mock_stack.return_value = torch.randn(6, 2, 2, 9)
|
||||||
mock_split.return_value = (torch.randn(6, 2, 2,
|
mock_split.return_value = (torch.randn(6, 2, 2,
|
||||||
8), torch.randn(6, 2, 2, 1))
|
8), torch.randn(6, 2, 2, 1))
|
||||||
|
mock_pcp_group = MagicMock()
|
||||||
|
mock_pcp_group.all_gather.return_value = torch.randn(6, 4, 9)
|
||||||
|
mock_get_pcp_group.return_value = mock_pcp_group
|
||||||
|
|
||||||
# Call the method under test
|
# Call the method under test
|
||||||
output, lse = self.impl._update_chunk_attn_out_lse(
|
output, lse = self.impl._update_chunk_attn_out_lse(
|
||||||
@@ -700,10 +727,10 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
|
|
||||||
self.assertEqual(mock_cat.call_count, 1)
|
self.assertEqual(mock_cat.call_count, 1)
|
||||||
mock_all_to_all_single.assert_called_once()
|
mock_all_to_all_single.assert_called_once()
|
||||||
mock_stack.assert_called_once()
|
self.assertEqual(mock_get_pcp_group.call_count, 1)
|
||||||
mock_split.assert_called_once()
|
|
||||||
self.assertEqual(mock_all_gather.call_count, 1)
|
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state._PCP')
|
||||||
@patch('torch.cat')
|
@patch('torch.cat')
|
||||||
@patch('torch.chunk')
|
@patch('torch.chunk')
|
||||||
@patch('torch.stack')
|
@patch('torch.stack')
|
||||||
@@ -712,7 +739,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
@patch('torch.distributed.all_gather')
|
@patch('torch.distributed.all_gather')
|
||||||
def test_update_chunk_attn_out_lse_dcp_greater_than_1_only(
|
def test_update_chunk_attn_out_lse_dcp_greater_than_1_only(
|
||||||
self, mock_all_gather, mock_all_to_all_single, mock_split,
|
self, mock_all_gather, mock_all_to_all_single, mock_split,
|
||||||
mock_stack, mock_chunk, mock_cat):
|
mock_stack, mock_chunk, mock_cat, mock_pcp, mock_pcp_group):
|
||||||
# Mock input data
|
# Mock input data
|
||||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||||
@@ -723,7 +750,8 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
|
|
||||||
# Mock output
|
# Mock output
|
||||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||||
mock_all_to_all_single.return_value = torch.randn(2, 4, 9)
|
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||||
|
input)
|
||||||
mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)]
|
mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)]
|
||||||
mock_stack.return_value = torch.randn(2, 2, 2, 9)
|
mock_stack.return_value = torch.randn(2, 2, 2, 9)
|
||||||
mock_split.return_value = [
|
mock_split.return_value = [
|
||||||
@@ -743,11 +771,10 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
|
|
||||||
self.assertEqual(mock_cat.call_count, 1)
|
self.assertEqual(mock_cat.call_count, 1)
|
||||||
mock_all_to_all_single.assert_called_once()
|
mock_all_to_all_single.assert_called_once()
|
||||||
mock_chunk.assert_called_once()
|
|
||||||
mock_stack.assert_called_once()
|
|
||||||
mock_split.assert_called_once()
|
|
||||||
mock_all_gather.assert_not_called()
|
mock_all_gather.assert_not_called()
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||||
|
@patch('vllm.distributed.parallel_state._PCP')
|
||||||
@patch('torch.cat')
|
@patch('torch.cat')
|
||||||
@patch('torch.stack')
|
@patch('torch.stack')
|
||||||
@patch('torch.split')
|
@patch('torch.split')
|
||||||
@@ -758,7 +785,8 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
)
|
)
|
||||||
def test_update_chunk_attn_out_lse_pcp_greater_than_1_only(
|
def test_update_chunk_attn_out_lse_pcp_greater_than_1_only(
|
||||||
self, mock_update_out_and_lse, mock_all_gather,
|
self, mock_update_out_and_lse, mock_all_gather,
|
||||||
mock_all_to_all_single, mock_split, mock_stack, mock_cat):
|
mock_all_to_all_single, mock_split, mock_stack, mock_cat, mock_pcp,
|
||||||
|
mock_get_pcp_group):
|
||||||
# Mock input data
|
# Mock input data
|
||||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||||
@@ -769,7 +797,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
|
|
||||||
# Mock output
|
# Mock output
|
||||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||||
mock_all_gather.return_value = [(2, 4, 9), (2, 4, 9)]
|
mock_pcp_group = MagicMock()
|
||||||
|
mock_pcp_group.all_gather.return_value = torch.randn(4, 4, 9)
|
||||||
|
mock_get_pcp_group.return_value = mock_pcp_group
|
||||||
mock_stack.return_value = torch.randn(2, 2, 4, 9)
|
mock_stack.return_value = torch.randn(2, 2, 4, 9)
|
||||||
mock_split.return_value = [
|
mock_split.return_value = [
|
||||||
torch.randn(2, 2, 4, 8),
|
torch.randn(2, 2, 4, 8),
|
||||||
@@ -791,6 +821,4 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
|||||||
|
|
||||||
self.assertEqual(mock_cat.call_count, 1)
|
self.assertEqual(mock_cat.call_count, 1)
|
||||||
mock_all_to_all_single.assert_not_called()
|
mock_all_to_all_single.assert_not_called()
|
||||||
mock_stack.assert_called_once()
|
mock_get_pcp_group.assert_called_once()
|
||||||
mock_split.assert_called_once()
|
|
||||||
mock_all_gather.assert_called_once()
|
|
||||||
|
|||||||
@@ -428,26 +428,36 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
|
attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
|
||||||
return attn_out, attn_lse
|
return attn_out, attn_lse
|
||||||
|
|
||||||
def _npu_attention_update(
|
def _npu_attention_update(self,
|
||||||
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
|
attn_out_lse: torch.Tensor) -> torch.Tensor:
|
||||||
|
B_total, H_total, D_plus_1 = attn_out_lse.shape
|
||||||
|
S = B_total // self.pcp_size
|
||||||
|
H = H_total // self.dcp_size
|
||||||
|
D = self.head_size
|
||||||
update_type = 0
|
update_type = 0
|
||||||
|
assert D_plus_1 == D + 1
|
||||||
batch = attn_out_lse_list[0].shape[0]
|
# [PCP, S, DCP, H, D+1]
|
||||||
num_heads = attn_out_lse_list[0].shape[1]
|
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
|
||||||
head_dim = attn_out_lse_list[0].shape[2] - 1
|
# [PCP, DCP, S, H, D+1]
|
||||||
|
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
attn_out_split_cp = []
|
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
|
||||||
attn_lse_split_cp = []
|
x = x.view(-1, S, H, D_plus_1)
|
||||||
|
# Split out lse
|
||||||
for i in attn_out_lse_list:
|
# [N, S, H, D], [N, S, H, 1]
|
||||||
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
|
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
|
||||||
*torch.split(i, [self.head_size, 1], dim=-1))
|
# out: [N, S, H, D] -> [N, S*H, D]
|
||||||
attn_out_split_cp.append(attn_out_allgather)
|
# lse: [N, S, H, 1] -> [N, S*H]
|
||||||
attn_lse_split_cp.append(attn_lse_allgather)
|
out_flat = out_flat.flatten(1, 2)
|
||||||
|
lse_flat = lse_flat.squeeze(-1).flatten(1)
|
||||||
|
# unbind to list
|
||||||
|
# [S*H, D]
|
||||||
|
out_list = out_flat.unbind(0)
|
||||||
|
# [S*H]
|
||||||
|
lse_list = lse_flat.unbind(0)
|
||||||
|
|
||||||
attn_out, attn_lse = torch_npu.npu_attention_update(
|
attn_out, attn_lse = torch_npu.npu_attention_update(
|
||||||
attn_lse_split_cp, attn_out_split_cp, update_type)
|
lse_list, out_list, update_type)
|
||||||
attn_out = attn_out.view(batch, num_heads, head_dim)
|
attn_out = attn_out.view(S, H, D)
|
||||||
|
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
@@ -539,17 +549,10 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
|
||||||
query, k_nope, value, **common_kwargs)
|
query, k_nope, value, **common_kwargs)
|
||||||
|
|
||||||
out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
|
|
||||||
None].expand_as(
|
|
||||||
attn_out)
|
|
||||||
attn_out = torch.where(out_mask, 0, attn_out)
|
|
||||||
|
|
||||||
lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
|
lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
|
||||||
None].expand_as(
|
None].expand_as(
|
||||||
attn_lse)
|
attn_lse)
|
||||||
attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
|
attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
|
||||||
|
|
||||||
attn_out_lse_list = []
|
|
||||||
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
|
||||||
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
|
attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
|
||||||
if self.dcp_size > 1:
|
if self.dcp_size > 1:
|
||||||
@@ -559,30 +562,14 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
dist.all_to_all_single(attn_out_lse_all2all,
|
dist.all_to_all_single(attn_out_lse_all2all,
|
||||||
attn_out_lse,
|
attn_out_lse,
|
||||||
group=self.dcp_group)
|
group=self.dcp_group)
|
||||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
|
||||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
attn_out_lse = attn_out_lse_all2all.contiguous()
|
|
||||||
attn_out_lse_list = list(
|
|
||||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
# AllGather out&lse within CP group
|
# AllGather out&lse within CP group
|
||||||
attn_out_lse_list = [
|
attn_out_lse = get_pcp_group().all_gather(
|
||||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
attn_out_lse.contiguous(), dim=0)
|
||||||
]
|
|
||||||
dist.all_gather(attn_out_lse_list,
|
attn_out = self._npu_attention_update(attn_out_lse)
|
||||||
attn_out_lse,
|
|
||||||
group=self.pcp_group)
|
|
||||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
|
||||||
attn_out_lse_list_pcp_dcp = []
|
|
||||||
for s in attn_out_lse_list:
|
|
||||||
attn_out_lse_list_split = list(
|
|
||||||
torch.chunk(s, self.dcp_size, dim=1))
|
|
||||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
|
||||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
|
||||||
# Update out&lse
|
|
||||||
attn_out = self._npu_attention_update(attn_out_lse_list)
|
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def _update_out_and_lse(self, out_list: torch.Tensor,
|
def _update_out_and_lse(self, out_list: torch.Tensor,
|
||||||
@@ -739,35 +726,28 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
dist.all_to_all_single(attn_out_lse_all2all,
|
dist.all_to_all_single(attn_out_lse_all2all,
|
||||||
chunk_attn_out_lse,
|
chunk_attn_out_lse,
|
||||||
group=self.dcp_group)
|
group=self.dcp_group)
|
||||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
chunk_attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
|
||||||
if self.pcp_size > 1:
|
|
||||||
chunk_attn_out_lse = attn_out_lse_all2all.contiguous()
|
|
||||||
|
|
||||||
attn_out_lse_list = list(
|
|
||||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
attn_out_lse_list = [
|
# AllGather out&lse within CP group
|
||||||
torch.empty_like(chunk_attn_out_lse)
|
chunk_attn_out_lse = get_pcp_group().all_gather(
|
||||||
for _ in range(self.pcp_size)
|
chunk_attn_out_lse.contiguous(), dim=0)
|
||||||
]
|
|
||||||
dist.all_gather(attn_out_lse_list,
|
|
||||||
chunk_attn_out_lse,
|
|
||||||
group=self.pcp_group)
|
|
||||||
|
|
||||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
B_total, H_total, D_plus_1 = chunk_attn_out_lse.shape
|
||||||
attn_out_lse_list_pcp_dcp = []
|
S = B_total // self.pcp_size
|
||||||
for s in attn_out_lse_list:
|
H = H_total // self.dcp_size
|
||||||
attn_out_lse_list_split = list(
|
D = self.head_size
|
||||||
torch.chunk(s, self.dcp_size, dim=1))
|
assert D_plus_1 == D + 1
|
||||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
# [PCP, S, DCP, H, D+1]
|
||||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
x = chunk_attn_out_lse.view(self.pcp_size, S, self.dcp_size, H,
|
||||||
|
D_plus_1)
|
||||||
attn_out_lse_allgather = torch.stack(
|
# [PCP, DCP, S, H, D+1]
|
||||||
attn_out_lse_list,
|
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
dim=0) # [pcp, batch_size, num_heads, head_size+1]
|
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
|
||||||
attn_out_allgather, attn_lse_allgather = torch.split(
|
x = x.view(-1, S, H, D_plus_1)
|
||||||
attn_out_lse_allgather, [self.head_size, 1], dim=-1)
|
# Split out lse.
|
||||||
|
# [N, S, H, D], [N, S, H, 1]
|
||||||
|
attn_out_allgather, attn_lse_allgather = torch.split(x, [D, 1], dim=-1)
|
||||||
|
|
||||||
prefix_output, prefix_lse = self._update_out_and_lse(
|
prefix_output, prefix_lse = self._update_out_and_lse(
|
||||||
attn_out_allgather, attn_lse_allgather)
|
attn_out_allgather, attn_lse_allgather)
|
||||||
@@ -842,19 +822,21 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
pcp_allgather_restore_idx)
|
pcp_allgather_restore_idx)
|
||||||
key, value = all_kv.split([self.head_size, self.head_size],
|
key, value = all_kv.split([self.head_size, self.head_size],
|
||||||
dim=-1)
|
dim=-1)
|
||||||
|
prefill_key = key[self.pcp_size *
|
||||||
torch_npu._npu_reshape_and_cache(
|
|
||||||
key=key[self.pcp_size * num_decode_tokens:attn_metadata.
|
|
||||||
num_actual_tokens_pcp_padded],
|
|
||||||
value=value[self.pcp_size *
|
|
||||||
num_decode_tokens:attn_metadata.
|
num_decode_tokens:attn_metadata.
|
||||||
num_actual_tokens_pcp_padded],
|
num_actual_tokens_pcp_padded]
|
||||||
|
prefill_value = value[self.pcp_size *
|
||||||
|
num_decode_tokens:attn_metadata.
|
||||||
|
num_actual_tokens_pcp_padded]
|
||||||
|
slot_mapping = attn_metadata.slot_mapping[
|
||||||
|
self.pcp_size * num_decode_tokens:attn_metadata.
|
||||||
|
num_actual_tokens_pcp_padded]
|
||||||
|
torch_npu._npu_reshape_and_cache(key=prefill_key,
|
||||||
|
value=prefill_value,
|
||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=self.value_cache,
|
value_cache=self.value_cache,
|
||||||
slot_indices=attn_metadata.
|
slot_indices=slot_mapping)
|
||||||
slot_mapping[self.pcp_size *
|
|
||||||
num_decode_tokens:attn_metadata.
|
|
||||||
num_actual_tokens_pcp_padded])
|
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
def forward_impl(
|
def forward_impl(
|
||||||
@@ -879,9 +861,9 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
|
|||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
num_actual_tokens_pcp_padded = attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size
|
||||||
prefill_query = query[
|
prefill_query = query[
|
||||||
num_decode_tokens:num_actual_tokens_pcp_padded]
|
num_decode_tokens:num_actual_tokens_pcp_padded].contiguous()
|
||||||
key = key[self.pcp_size * num_decode_tokens:]
|
key = key[self.pcp_size * num_decode_tokens:].contiguous()
|
||||||
value = value[self.pcp_size * num_decode_tokens:]
|
value = value[self.pcp_size * num_decode_tokens:].contiguous()
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
# Scenario of Enabling PCP or PCP&DCP
|
# Scenario of Enabling PCP or PCP&DCP
|
||||||
attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
|
attn_output_prefill, attn_lse_prefill = self._forward_prefill_cp(
|
||||||
|
|||||||
Reference in New Issue
Block a user