[Perf] vectorize PCP/DCP loops in attention_cp.py (#4944)
### What this PR does / why we need it?
- Add explicit .contiguous() after permute/view to ensure mem-friendly
layout
- Replace nested PCP/DCP Python loops with fully vectorized tensor
operations
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
This commit is contained in:
@@ -106,6 +106,8 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
self.assertEqual(output.shape[1], 4)
|
||||
self.assertEqual(output.shape[2], 128)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP')
|
||||
@patch("torch_npu.npu_fused_infer_attention_score")
|
||||
@@ -115,9 +117,10 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
|
||||
mock_all_to_all_single, mock_all_gather,
|
||||
mock_npu_fused_infer_attention_score,
|
||||
mock_dcp, mock_get_dcp_group):
|
||||
mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||
mock_pcp_group):
|
||||
|
||||
def mock_dcp_all_gather_func(tensor, dim):
|
||||
def mock_all_gather_func(tensor, dim):
|
||||
return torch.cat([tensor, tensor], dim=dim)
|
||||
|
||||
mock_dcp.world_size = 2
|
||||
@@ -126,17 +129,27 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 2
|
||||
dcp_group.device_group = MagicMock()
|
||||
dcp_group.all_gather = mock_dcp_all_gather_func
|
||||
dcp_group.all_gather = mock_all_gather_func
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 2
|
||||
mock_pcp.rank_in_group = 0
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 2
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.all_gather = mock_all_gather_func
|
||||
mock_pcp_group.return_value = pcp_group
|
||||
|
||||
query = torch.randn(2, 4, 128)
|
||||
self.impl.key_cache = torch.randn(100, 128, 1, 128)
|
||||
self.impl.value_cache = torch.randn(100, 128, 1, 128)
|
||||
|
||||
def mock_npu_attention_update(attn_out_lse_list):
|
||||
mock_output = torch.randn(attn_out_lse_list[0].shape[0],
|
||||
attn_out_lse_list[0].shape[1],
|
||||
attn_out_lse_list[0].shape[2] - 1)
|
||||
mock_output = torch.randn(
|
||||
attn_out_lse_list.shape[0] // mock_pcp.world_size,
|
||||
attn_out_lse_list.shape[1] // mock_dcp.world_size,
|
||||
attn_out_lse_list.shape[2] - 1)
|
||||
return mock_output
|
||||
|
||||
self.impl._npu_attention_update = MagicMock()
|
||||
@@ -147,11 +160,11 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
|
||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||
def mock_all_gather_func1(tensor_list, tensor, group=None):
|
||||
tensor_list[0] = tensor
|
||||
tensor_list[1] = tensor.clone()
|
||||
|
||||
mock_all_gather.side_effect = mock_all_gather_func
|
||||
mock_all_gather.side_effect = mock_all_gather_func1
|
||||
|
||||
def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
|
||||
**common_kwargs):
|
||||
@@ -202,8 +215,9 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
self.assertEqual(output.shape[1], 8)
|
||||
self.assertEqual(output.shape[2], 128)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('torch.ops.npu.npu_fused_infer_attention_score')
|
||||
def test_compute_prefill_context(self, mock_npu_attention):
|
||||
def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group):
|
||||
|
||||
block_num = 100
|
||||
block_size = 128
|
||||
@@ -238,6 +252,13 @@ class TestAscendAttentionCPImpl(TestBase):
|
||||
self.impl._load_kv_for_chunk = MagicMock()
|
||||
self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk
|
||||
|
||||
def mock_all_gather_func(tensor, dim):
|
||||
return torch.cat([tensor, tensor], dim=dim)
|
||||
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.all_gather = mock_all_gather_func
|
||||
mock_pcp_group.return_value = pcp_group
|
||||
|
||||
mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
|
||||
head_size), torch.randn(
|
||||
batch_size,
|
||||
@@ -666,6 +687,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
self.assertIsInstance(out_final, torch.Tensor)
|
||||
self.assertIsInstance(lse_final, torch.Tensor)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('torch.cat')
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('torch.distributed.all_gather')
|
||||
@@ -673,7 +697,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
@patch('torch.split')
|
||||
def test_update_chunk_attn_out_lse_dcp_pcp_both_greater_than_1(
|
||||
self, mock_split, mock_stack, mock_all_gather,
|
||||
mock_all_to_all_single, mock_cat):
|
||||
mock_all_to_all_single, mock_cat, mock_pcp, mock_get_pcp_group):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
@@ -687,6 +711,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
mock_stack.return_value = torch.randn(6, 2, 2, 9)
|
||||
mock_split.return_value = (torch.randn(6, 2, 2,
|
||||
8), torch.randn(6, 2, 2, 1))
|
||||
mock_pcp_group = MagicMock()
|
||||
mock_pcp_group.all_gather.return_value = torch.randn(6, 4, 9)
|
||||
mock_get_pcp_group.return_value = mock_pcp_group
|
||||
|
||||
# Call the method under test
|
||||
output, lse = self.impl._update_chunk_attn_out_lse(
|
||||
@@ -700,10 +727,10 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_called_once()
|
||||
mock_stack.assert_called_once()
|
||||
mock_split.assert_called_once()
|
||||
self.assertEqual(mock_all_gather.call_count, 1)
|
||||
self.assertEqual(mock_get_pcp_group.call_count, 1)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('torch.cat')
|
||||
@patch('torch.chunk')
|
||||
@patch('torch.stack')
|
||||
@@ -712,7 +739,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
@patch('torch.distributed.all_gather')
|
||||
def test_update_chunk_attn_out_lse_dcp_greater_than_1_only(
|
||||
self, mock_all_gather, mock_all_to_all_single, mock_split,
|
||||
mock_stack, mock_chunk, mock_cat):
|
||||
mock_stack, mock_chunk, mock_cat, mock_pcp, mock_pcp_group):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
@@ -723,7 +750,8 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
|
||||
# Mock output
|
||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||
mock_all_to_all_single.return_value = torch.randn(2, 4, 9)
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||
input)
|
||||
mock_chunk.return_value = [torch.randn(2, 2, 9), torch.randn(2, 2, 9)]
|
||||
mock_stack.return_value = torch.randn(2, 2, 2, 9)
|
||||
mock_split.return_value = [
|
||||
@@ -743,11 +771,10 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_called_once()
|
||||
mock_chunk.assert_called_once()
|
||||
mock_stack.assert_called_once()
|
||||
mock_split.assert_called_once()
|
||||
mock_all_gather.assert_not_called()
|
||||
|
||||
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('torch.cat')
|
||||
@patch('torch.stack')
|
||||
@patch('torch.split')
|
||||
@@ -758,7 +785,8 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
)
|
||||
def test_update_chunk_attn_out_lse_pcp_greater_than_1_only(
|
||||
self, mock_update_out_and_lse, mock_all_gather,
|
||||
mock_all_to_all_single, mock_split, mock_stack, mock_cat):
|
||||
mock_all_to_all_single, mock_split, mock_stack, mock_cat, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
# Mock input data
|
||||
prefix_chunk_output = torch.randn(2, 4, 8)
|
||||
prefix_chunk_lse = torch.randn(2, 4, 1)
|
||||
@@ -769,7 +797,9 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
|
||||
# Mock output
|
||||
mock_cat.return_value = torch.randn(2, 4, 9)
|
||||
mock_all_gather.return_value = [(2, 4, 9), (2, 4, 9)]
|
||||
mock_pcp_group = MagicMock()
|
||||
mock_pcp_group.all_gather.return_value = torch.randn(4, 4, 9)
|
||||
mock_get_pcp_group.return_value = mock_pcp_group
|
||||
mock_stack.return_value = torch.randn(2, 2, 4, 9)
|
||||
mock_split.return_value = [
|
||||
torch.randn(2, 2, 4, 8),
|
||||
@@ -791,6 +821,4 @@ class TestUpdateNpuAttnOutLse(TestBase):
|
||||
|
||||
self.assertEqual(mock_cat.call_count, 1)
|
||||
mock_all_to_all_single.assert_not_called()
|
||||
mock_stack.assert_called_once()
|
||||
mock_split.assert_called_once()
|
||||
mock_all_gather.assert_called_once()
|
||||
mock_get_pcp_group.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user