[Perf] vectorize PCP/DCP loops in mla_v1.py (#5003)
### What this PR does / why we need it?
- Replace nested PCP/DCP Python loops with fully vectorized tensor
operations
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
This commit is contained in:
@@ -451,10 +451,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.assertIsNone(decode_res)
|
self.assertIsNone(decode_res)
|
||||||
self.assertIsNotNone(prefill_res)
|
self.assertIsNotNone(prefill_res)
|
||||||
|
|
||||||
@patch("torch.distributed.all_gather")
|
@patch('vllm.distributed.parallel_state._PCP',
|
||||||
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("torch.distributed.all_to_all_single")
|
@patch("torch.distributed.all_to_all_single")
|
||||||
def test_process_attn_out_lse(self, mock_all_to_all_single,
|
def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
|
||||||
mock_all_gather):
|
|
||||||
self.impl.dcp_size = 2
|
self.impl.dcp_size = 2
|
||||||
self.impl.pcp_size = 2
|
self.impl.pcp_size = 2
|
||||||
|
|
||||||
@@ -468,11 +468,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||||
input)
|
input)
|
||||||
|
|
||||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
def make_all_gather(ws):
|
||||||
tensor_list[0] = tensor
|
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||||
tensor_list[1] = tensor.clone()
|
|
||||||
|
|
||||||
mock_all_gather.side_effect = mock_all_gather_func
|
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
|
||||||
|
|
||||||
decode_metadata = MagicMock()
|
decode_metadata = MagicMock()
|
||||||
decode_metadata.actual_seq_lengths_q = MagicMock()
|
decode_metadata.actual_seq_lengths_q = MagicMock()
|
||||||
@@ -483,11 +482,12 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
|
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
|
||||||
decode_metadata)
|
decode_metadata)
|
||||||
|
|
||||||
self.assertEqual(result[0].shape[0], B)
|
self.assertEqual(result.shape[0], B * self.impl.pcp_size)
|
||||||
self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
|
self.assertEqual(result.shape[1], N)
|
||||||
self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)
|
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
|
||||||
|
|
||||||
@patch("torch.distributed.all_gather")
|
@patch('vllm.distributed.parallel_state._PCP',
|
||||||
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
@patch("torch.distributed.all_to_all_single")
|
@patch("torch.distributed.all_to_all_single")
|
||||||
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
|
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
|
||||||
@patch("torch_npu.atb.npu_multi_head_latent_attention")
|
@patch("torch_npu.atb.npu_multi_head_latent_attention")
|
||||||
@@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
|
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
|
||||||
mock_npu_multi_head_latent_attention,
|
mock_npu_multi_head_latent_attention,
|
||||||
mock_get_forward_context,
|
mock_get_forward_context,
|
||||||
mock_all_to_all_single, mock_all_gather):
|
mock_all_to_all_single, mock_pcp):
|
||||||
self.impl.dcp_size = 2
|
self.impl.dcp_size = 2
|
||||||
self.impl.pcp_size = 2
|
self.impl.pcp_size = 2
|
||||||
self.impl.num_kv_heads = 1
|
self.impl.num_kv_heads = 1
|
||||||
@@ -534,11 +534,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
|
||||||
input)
|
input)
|
||||||
|
|
||||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
def make_all_gather(ws):
|
||||||
tensor_list[0] = tensor
|
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||||
tensor_list[1] = tensor.clone()
|
|
||||||
|
|
||||||
mock_all_gather.side_effect = mock_all_gather_func
|
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
|
||||||
|
|
||||||
self.impl._v_up_proj = MagicMock()
|
self.impl._v_up_proj = MagicMock()
|
||||||
self.impl._v_up_proj.return_value = torch.randn(
|
self.impl._v_up_proj.return_value = torch.randn(
|
||||||
@@ -562,9 +561,6 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
def mock_all_gather(ws):
|
def mock_all_gather(ws):
|
||||||
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||||
|
|
||||||
mock_dcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
|
|
||||||
mock_pcp.all_gather = MagicMock(side_effect=mock_all_gather(2))
|
|
||||||
|
|
||||||
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
|
def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
|
||||||
head_num, kv_head_num, pre_out, prev_lse, qk_scale,
|
head_num, kv_head_num, pre_out, prev_lse, qk_scale,
|
||||||
kernel_type, mask_type, input_layout, calc_type,
|
kernel_type, mask_type, input_layout, calc_type,
|
||||||
@@ -624,6 +620,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
torch.ones(10, 10, dtype=torch.float16), 1)
|
torch.ones(10, 10, dtype=torch.float16), 1)
|
||||||
for test_case in test_cases:
|
for test_case in test_cases:
|
||||||
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
|
pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
|
||||||
|
mock_dcp.all_gather = MagicMock(
|
||||||
|
side_effect=mock_all_gather(dcp_size))
|
||||||
|
mock_pcp.all_gather = MagicMock(
|
||||||
|
side_effect=mock_all_gather(pcp_size))
|
||||||
assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
|
assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
|
||||||
nums_context_per_rank = []
|
nums_context_per_rank = []
|
||||||
for num_all_rank_context in nums_all_rank_context:
|
for num_all_rank_context in nums_all_rank_context:
|
||||||
@@ -804,11 +804,10 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
attn_lse_split_cp[0])
|
attn_lse_split_cp[0])
|
||||||
|
|
||||||
mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
|
mock_npu_attention_update.side_effect = mock_npu_attention_update_effect
|
||||||
attn_out_lse_list = [
|
attn_out_lse = torch.randn(self.impl.pcp_size * NUM_TOKENS,
|
||||||
torch.randn(NUM_TOKENS, num_heads, head_dim)
|
self.impl.dcp_size * num_heads,
|
||||||
for _ in range(self.impl.pcp_size * self.impl.dcp_size)
|
head_dim)
|
||||||
]
|
out = self.impl._npu_attention_update(attn_out_lse)
|
||||||
out = self.impl._npu_attention_update(attn_out_lse_list)
|
|
||||||
self.impl.dcp_size = 1
|
self.impl.dcp_size = 1
|
||||||
self.impl.pcp_size = 1
|
self.impl.pcp_size = 1
|
||||||
assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
|
assert out.shape == (NUM_TOKENS, num_heads, self.impl.kv_lora_rank)
|
||||||
@@ -908,12 +907,14 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
mock_npu_ring_mla.reset_mock()
|
mock_npu_ring_mla.reset_mock()
|
||||||
|
|
||||||
@patch("torch.distributed.all_to_all_single")
|
@patch("torch.distributed.all_to_all_single")
|
||||||
@patch("torch.distributed.all_gather")
|
@patch('vllm.distributed.parallel_state._PCP',
|
||||||
def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_gather,
|
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||||
|
def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
|
||||||
mock_all_to_all):
|
mock_all_to_all):
|
||||||
B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim # total: [4, 4, 8]
|
B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim # total: [4, 4, 8]
|
||||||
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
|
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
|
||||||
for test_case in test_cases:
|
for test_case in test_cases:
|
||||||
|
print(test_case)
|
||||||
self.impl.dcp_size = test_case[0]
|
self.impl.dcp_size = test_case[0]
|
||||||
self.impl.pcp_size = test_case[1]
|
self.impl.pcp_size = test_case[1]
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -928,26 +929,17 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
|
|
||||||
mock_all_to_all.side_effect = mock_all_to_all_side_effect
|
mock_all_to_all.side_effect = mock_all_to_all_side_effect
|
||||||
|
|
||||||
def mock_all_gather_side_effect(tensor_list, tensor, group=None):
|
def mock_all_gather(ws):
|
||||||
for i in range(len(tensor_list)):
|
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
|
||||||
tensor_list[i].copy_(tensor)
|
|
||||||
|
|
||||||
mock_all_gather.side_effect = mock_all_gather_side_effect
|
mock_pcp.all_gather = MagicMock(
|
||||||
|
side_effect=mock_all_gather(self.impl.pcp_size))
|
||||||
|
|
||||||
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
|
result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
|
||||||
decode_meta)
|
decode_meta)
|
||||||
|
# [PCP * S, DCP * H, D + 1]
|
||||||
self.assertIsInstance(result, list)
|
self.assertIsInstance(result, torch.Tensor)
|
||||||
if self.impl.dcp_size == 1 and self.impl.pcp_size == 1:
|
assert result.shape == (B * self.impl.pcp_size, H, D + 1)
|
||||||
self.assertEqual(len(result), 0)
|
|
||||||
else:
|
|
||||||
self.assertEqual(len(result),
|
|
||||||
self.impl.dcp_size * self.impl.pcp_size) # 4
|
|
||||||
|
|
||||||
for tensor in result:
|
|
||||||
self.assertEqual(tensor.dtype, torch.float32)
|
|
||||||
self.assertEqual(tensor.shape,
|
|
||||||
(B, H // self.impl.dcp_size, D + 1))
|
|
||||||
self.impl.dcp_size = 1
|
self.impl.dcp_size = 1
|
||||||
self.impl.pcp_size = 1
|
self.impl.pcp_size = 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import ClassVar, List, Optional, Tuple, TypeVar
|
from typing import ClassVar, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -1120,26 +1120,37 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
lse=softmax_lse)
|
lse=softmax_lse)
|
||||||
|
|
||||||
# Update out&lse
|
# Update out&lse
|
||||||
attn_out_lse_list = self._process_attn_out_lse(attn_output,
|
attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse,
|
||||||
softmax_lse,
|
decode_meta)
|
||||||
decode_meta)
|
attn_output = self._npu_attention_update(attn_out_lse)
|
||||||
attn_output = self._npu_attention_update(attn_out_lse_list)
|
|
||||||
return self._v_up_proj(attn_output)
|
return self._v_up_proj(attn_output)
|
||||||
|
|
||||||
def _npu_attention_update(
|
def _npu_attention_update(self,
|
||||||
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
|
attn_out_lse: torch.Tensor) -> torch.Tensor:
|
||||||
attn_out_split_cp = []
|
# [PCP * S, DCP * H, D+1]
|
||||||
attn_lse_split_cp = []
|
B_total, H_total, D_plus_1 = attn_out_lse.shape
|
||||||
|
S = B_total // self.pcp_size
|
||||||
for attn_out_lse in attn_out_lse_list:
|
H = H_total // self.dcp_size
|
||||||
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
|
D = self.kv_lora_rank
|
||||||
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
|
assert D_plus_1 == D + 1
|
||||||
attn_out_split_cp.append(attn_out_allgather)
|
# [PCP, S, DCP, H, D+1]
|
||||||
attn_lse_split_cp.append(attn_lse_allgather)
|
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
|
||||||
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
|
# [PCP, DCP, S, H, D+1]
|
||||||
attn_out_split_cp, 0)
|
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||||
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
|
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
|
||||||
self.kv_lora_rank)
|
x = x.view(-1, S, H, D_plus_1)
|
||||||
|
# Split out lse
|
||||||
|
out_flat, lse_flat = torch.split(x, [D, 1],
|
||||||
|
dim=-1) # [N, S, H, D], [N, S, H, 1]
|
||||||
|
# out: [N, S, H, D] -> [N, S*H, D]
|
||||||
|
# lse: [N, S, H, 1] -> [N, S*H]
|
||||||
|
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
|
||||||
|
lse_flat = lse_flat.flatten(1, -1) # [N, S*H]
|
||||||
|
# unbind to list
|
||||||
|
out_list = out_flat.unbind(0) # [S*H, D]
|
||||||
|
lse_list = lse_flat.unbind(0) # [S*H]
|
||||||
|
attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
|
||||||
|
attn_out = attn_out.view(-1, H, D)
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||||
@@ -1155,8 +1166,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
attn_output: torch.Tensor,
|
attn_output: torch.Tensor,
|
||||||
softmax_lse: torch.Tensor,
|
softmax_lse: torch.Tensor,
|
||||||
decode_meta: AscendMLADecodeMetadata,
|
decode_meta: AscendMLADecodeMetadata,
|
||||||
) -> List[torch.Tensor]:
|
) -> torch.Tensor:
|
||||||
attn_out_lse_list = []
|
|
||||||
out_mask = decode_meta.batch_seq_mask[:, None,
|
out_mask = decode_meta.batch_seq_mask[:, None,
|
||||||
None].expand_as(attn_output)
|
None].expand_as(attn_output)
|
||||||
attn_output = torch.where(out_mask, 0, attn_output)
|
attn_output = torch.where(out_mask, 0, attn_output)
|
||||||
@@ -1175,30 +1185,14 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
|||||||
dist.all_to_all_single(attn_out_lse_all2all,
|
dist.all_to_all_single(attn_out_lse_all2all,
|
||||||
attn_out_lse,
|
attn_out_lse,
|
||||||
group=self.dcp_group)
|
group=self.dcp_group)
|
||||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
|
||||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
attn_out_lse = attn_out_lse_all2all.contiguous()
|
|
||||||
attn_out_lse_list = list(
|
|
||||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
|
||||||
|
|
||||||
if self.pcp_size > 1:
|
if self.pcp_size > 1:
|
||||||
# AllGather out&lse within PCP group
|
# AllGather out&lse within CP group
|
||||||
attn_out_lse_list = [
|
attn_out_lse = get_pcp_group().all_gather(
|
||||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
attn_out_lse.contiguous(), dim=0)
|
||||||
]
|
|
||||||
dist.all_gather(attn_out_lse_list,
|
|
||||||
attn_out_lse,
|
|
||||||
group=self.pcp_group)
|
|
||||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
|
||||||
attn_out_lse_list_pcp_dcp = []
|
|
||||||
for s in attn_out_lse_list:
|
|
||||||
attn_out_lse_list_split = list(
|
|
||||||
torch.chunk(s, self.dcp_size, dim=1))
|
|
||||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
|
||||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
|
||||||
|
|
||||||
return attn_out_lse_list
|
return attn_out_lse
|
||||||
|
|
||||||
def _reorg_kvcache(
|
def _reorg_kvcache(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user