[Perf] vectorize PCP/DCP loops in mla_v1.py (#5003)
### What this PR does / why we need it?
- Replace nested PCP/DCP Python loops with fully vectorized tensor
operations
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import ClassVar, List, Optional, Tuple, TypeVar
|
||||
from typing import ClassVar, Optional, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -1120,26 +1120,37 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
lse=softmax_lse)
|
||||
|
||||
# Update out&lse
|
||||
attn_out_lse_list = self._process_attn_out_lse(attn_output,
|
||||
softmax_lse,
|
||||
decode_meta)
|
||||
attn_output = self._npu_attention_update(attn_out_lse_list)
|
||||
attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse,
|
||||
decode_meta)
|
||||
attn_output = self._npu_attention_update(attn_out_lse)
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
def _npu_attention_update(
|
||||
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
|
||||
attn_out_split_cp = []
|
||||
attn_lse_split_cp = []
|
||||
|
||||
for attn_out_lse in attn_out_lse_list:
|
||||
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
|
||||
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
|
||||
attn_out_split_cp.append(attn_out_allgather)
|
||||
attn_lse_split_cp.append(attn_lse_allgather)
|
||||
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
|
||||
attn_out_split_cp, 0)
|
||||
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
|
||||
self.kv_lora_rank)
|
||||
def _npu_attention_update(self,
|
||||
attn_out_lse: torch.Tensor) -> torch.Tensor:
|
||||
# [PCP * S, DCP * H, D+1]
|
||||
B_total, H_total, D_plus_1 = attn_out_lse.shape
|
||||
S = B_total // self.pcp_size
|
||||
H = H_total // self.dcp_size
|
||||
D = self.kv_lora_rank
|
||||
assert D_plus_1 == D + 1
|
||||
# [PCP, S, DCP, H, D+1]
|
||||
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
|
||||
# [PCP, DCP, S, H, D+1]
|
||||
x = x.permute(0, 2, 1, 3, 4).contiguous()
|
||||
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
|
||||
x = x.view(-1, S, H, D_plus_1)
|
||||
# Split out lse
|
||||
out_flat, lse_flat = torch.split(x, [D, 1],
|
||||
dim=-1) # [N, S, H, D], [N, S, H, 1]
|
||||
# out: [N, S, H, D] -> [N, S*H, D]
|
||||
# lse: [N, S, H, 1] -> [N, S*H]
|
||||
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
|
||||
lse_flat = lse_flat.flatten(1, -1) # [N, S*H]
|
||||
# unbind to list
|
||||
out_list = out_flat.unbind(0) # [S*H, D]
|
||||
lse_list = lse_flat.unbind(0) # [S*H]
|
||||
attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
|
||||
attn_out = attn_out.view(-1, H, D)
|
||||
return attn_out
|
||||
|
||||
def _out_lse_reshape(self, attn_out: torch.Tensor,
|
||||
@@ -1155,8 +1166,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
attn_output: torch.Tensor,
|
||||
softmax_lse: torch.Tensor,
|
||||
decode_meta: AscendMLADecodeMetadata,
|
||||
) -> List[torch.Tensor]:
|
||||
attn_out_lse_list = []
|
||||
) -> torch.Tensor:
|
||||
out_mask = decode_meta.batch_seq_mask[:, None,
|
||||
None].expand_as(attn_output)
|
||||
attn_output = torch.where(out_mask, 0, attn_output)
|
||||
@@ -1175,30 +1185,14 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
dist.all_to_all_single(attn_out_lse_all2all,
|
||||
attn_out_lse,
|
||||
group=self.dcp_group)
|
||||
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
|
||||
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
|
||||
if self.pcp_size > 1:
|
||||
attn_out_lse = attn_out_lse_all2all.contiguous()
|
||||
attn_out_lse_list = list(
|
||||
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
|
||||
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])
|
||||
|
||||
if self.pcp_size > 1:
|
||||
# AllGather out&lse within PCP group
|
||||
attn_out_lse_list = [
|
||||
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
|
||||
]
|
||||
dist.all_gather(attn_out_lse_list,
|
||||
attn_out_lse,
|
||||
group=self.pcp_group)
|
||||
if self.dcp_size > 1 and self.pcp_size > 1:
|
||||
attn_out_lse_list_pcp_dcp = []
|
||||
for s in attn_out_lse_list:
|
||||
attn_out_lse_list_split = list(
|
||||
torch.chunk(s, self.dcp_size, dim=1))
|
||||
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
|
||||
attn_out_lse_list = attn_out_lse_list_pcp_dcp
|
||||
# AllGather out&lse within CP group
|
||||
attn_out_lse = get_pcp_group().all_gather(
|
||||
attn_out_lse.contiguous(), dim=0)
|
||||
|
||||
return attn_out_lse_list
|
||||
return attn_out_lse
|
||||
|
||||
def _reorg_kvcache(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user