MLA prefill performance optimization (#5456)
### What this PR does / why we need it?
Since the performance of the _npu_ring_mla operator deteriorates in
long-sequence scenarios, the long sequence is split into shorter sequences
before being passed to the operator, which improves performance.
- vLLM version: v0.13.0
- vLLM main:
5326c89803
---------
Signed-off-by: pichangping <1337510399@qq.com>
This commit is contained in:
@@ -465,11 +465,18 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
return output
|
||||
|
||||
def _attention_with_mask_and_nomask(
|
||||
self, q_nope: torch.Tensor, q_pe: torch.Tensor,
|
||||
k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor,
|
||||
kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor,
|
||||
attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor,
|
||||
mask: torch.Tensor):
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
k_nope: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_mask_idx: torch.Tensor,
|
||||
kv_nomask_idx: list[torch.Tensor],
|
||||
attn_mask_seqlens: torch.Tensor,
|
||||
attn_nomask_seqlens: list[torch.Tensor],
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
attn_output = torch.empty(q_nope.shape[0],
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
@@ -503,30 +510,33 @@ class AscendMlaCPImpl(AscendMLAImpl):
|
||||
softmax_lse=attn_lse)
|
||||
|
||||
# nomask
|
||||
if kv_nomask_idx.shape[0] == 0:
|
||||
if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0:
|
||||
return attn_output, attn_lse
|
||||
|
||||
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx)
|
||||
value_nomask = torch.index_select(value, 0, kv_nomask_idx)
|
||||
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx)
|
||||
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope_nomask,
|
||||
k_rope=k_pe_nomask,
|
||||
value=value_nomask,
|
||||
mask=mask,
|
||||
seqlen=attn_nomask_seqlens,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=attn_output,
|
||||
prev_lse=attn_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
for kv_nomask_idx_split, attn_nomask_seqlens_split in zip(
|
||||
kv_nomask_idx, attn_nomask_seqlens):
|
||||
k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split)
|
||||
value_nomask = torch.index_select(value, 0, kv_nomask_idx_split)
|
||||
k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope_nomask,
|
||||
k_rope=k_pe_nomask,
|
||||
value=value_nomask,
|
||||
mask=mask,
|
||||
seqlen=attn_nomask_seqlens_split,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=attn_output,
|
||||
prev_lse=attn_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
return attn_output, attn_lse
|
||||
|
||||
def _forward_decode(
|
||||
|
||||
@@ -565,6 +565,8 @@ class PCPManager:
|
||||
q_head_idx, q_tail_idx = [], []
|
||||
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
|
||||
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
|
||||
split_with_q_head_nomask_idx_reqs = []
|
||||
split_kv_with_q_tail_nomask_idx_reqs = []
|
||||
chunk_seqlens = []
|
||||
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
|
||||
q_req_offset = 0
|
||||
@@ -590,7 +592,10 @@ class PCPManager:
|
||||
(q_head_chunk_id + 1))))
|
||||
kv_with_q_head_nomask_seqlens.append(chunk_len *
|
||||
q_head_chunk_id)
|
||||
|
||||
split_with_q_head_nomask_idx_reqs.append(
|
||||
list(
|
||||
range(kv_req_offset, kv_req_offset +
|
||||
chunk_len * q_head_chunk_id)))
|
||||
q_tail_idx.extend(
|
||||
list(
|
||||
range(q_req_offset + chunk_len,
|
||||
@@ -607,21 +612,17 @@ class PCPManager:
|
||||
(q_tail_chunk_id + 1))))
|
||||
kv_with_q_tail_nomask_seqlens.append(chunk_len *
|
||||
q_tail_chunk_id)
|
||||
|
||||
split_kv_with_q_tail_nomask_idx_reqs.append(
|
||||
list(
|
||||
range(kv_req_offset, kv_req_offset +
|
||||
chunk_len * q_tail_chunk_id)))
|
||||
q_req_offset += seq_len
|
||||
kv_req_offset += seq_len * self.pcp_world_size
|
||||
|
||||
# Convert lists to tensors and move to device
|
||||
def _list_to_tensor(lst, device, dtype=torch.int32):
|
||||
tensor_npu = torch.zeros(len(lst),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
|
||||
non_blocking=True)
|
||||
return tensor_npu
|
||||
|
||||
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
|
||||
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
|
||||
q_head_idx_tensor = self._list_to_tensor(
|
||||
q_head_idx, self.device)
|
||||
q_tail_idx_tensor = self._list_to_tensor(
|
||||
q_tail_idx, self.device)
|
||||
self.q_head_idx_tensor = q_head_idx_tensor
|
||||
self.q_tail_idx_tensor = q_tail_idx_tensor
|
||||
|
||||
@@ -639,7 +640,7 @@ class PCPManager:
|
||||
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
|
||||
}
|
||||
for key, value in self.kv_idx_names.items():
|
||||
tensor_npu = _list_to_tensor(value, self.device)
|
||||
tensor_npu = self._list_to_tensor(value, self.device)
|
||||
self.kv_idx_names[key] = tensor_npu
|
||||
|
||||
attn_mask_seqlens = torch.tensor(
|
||||
@@ -650,6 +651,11 @@ class PCPManager:
|
||||
tail_attn_nomask_seqlens = torch.tensor(
|
||||
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
|
||||
dtype=torch.int32)
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list(
|
||||
split_with_q_head_nomask_idx_reqs,
|
||||
split_kv_with_q_tail_nomask_idx_reqs,
|
||||
head_attn_nomask_seqlens, chunk_seqlens)
|
||||
pcp_prefill_mask = attn_mask
|
||||
|
||||
self.extra_long_seq_kwargs = {
|
||||
@@ -680,5 +686,95 @@ class PCPManager:
|
||||
'tail_attn_nomask_seqlens']
|
||||
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
|
||||
'pcp_prefill_mask']
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
|
||||
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
|
||||
long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
|
||||
long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
|
||||
self.long_seq_metadata = long_seq_metadata
|
||||
return long_seq_metadata
|
||||
|
||||
def _list_to_tensor(self, lst, device, dtype=torch.int32):
    """Build a tensor on ``device`` holding the values of ``lst``.

    Args:
        lst: Flat sequence of values; its ``len`` sets the result size.
        device: Target device for the returned tensor.
        dtype: Element type of both the staging and the result tensor.

    Returns:
        A 1-D tensor of ``len(lst)`` elements on ``device``.
    """
    # Allocate the destination first, then fill it from a tensor built from
    # ``lst``; the copy is requested as non-blocking.
    destination = torch.zeros(len(lst), dtype=dtype, device=device)
    # ``copy_`` returns the tensor it was called on, i.e. ``destination``.
    return destination.copy_(torch.tensor(lst, dtype=dtype),
                             non_blocking=True)
|
||||
|
||||
def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs,
                                  split_kv_with_q_tail_nomask_idx_reqs,
                                  head_attn_nomask_seqlens, chunk_seqlens):
    """Split the no-mask KV indices into pieces and tensorize each piece.

    Args:
        split_with_q_head_nomask_idx_reqs: Per-request index lists for the
            "q head" no-mask part. When empty, nothing is computed.
        split_kv_with_q_tail_nomask_idx_reqs: Per-request index lists for
            the "q tail" no-mask part.
        head_attn_nomask_seqlens: Unsplit head seqlens, reused as-is when
            ``self.pcp_world_rank == 0``.
        chunk_seqlens: First row of every seqlens tensor built here.

    Returns:
        A 4-tuple ``(head_idx_tensors, tail_idx_tensors, head_seqlens,
        tail_seqlens)``, each a list. All four lists are empty when
        ``split_with_q_head_nomask_idx_reqs`` is empty.
    """
    if not split_with_q_head_nomask_idx_reqs:
        return [], [], [], []

    # In long-sequence scenarios, the computational cost and latency of the
    # _npu_ring_mla operator are not proportional, so we split long
    # sequences into shorter ones to improve performance.
    split_size = 16 * 1024
    is_first_rank = self.pcp_world_rank == 0

    if is_first_rank:
        # Rank 0 does not split the head part: it wraps the already-built
        # index entry from ``self.kv_idx_names`` in a one-element list.
        head_idx_pieces = [
            self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
        ]
    else:
        head_idx_pieces, head_len_pieces = self._split_multi_batch_kv_idx(
            split_with_q_head_nomask_idx_reqs, split_size)
    # The tail part is split on every rank.
    tail_idx_pieces, tail_len_pieces = self._split_multi_batch_kv_idx(
        split_kv_with_q_tail_nomask_idx_reqs, split_size)

    head_idx_tensors = [
        self._list_to_tensor(piece, self.device)
        for piece in head_idx_pieces
    ]
    tail_idx_tensors = [
        self._list_to_tensor(piece, self.device)
        for piece in tail_idx_pieces
    ]

    if is_first_rank:
        head_seqlens = [head_attn_nomask_seqlens]
    else:
        head_seqlens = [
            torch.tensor([chunk_seqlens, piece_lens], dtype=torch.int32)
            for piece_lens in head_len_pieces
        ]
    tail_seqlens = [
        torch.tensor([chunk_seqlens, piece_lens], dtype=torch.int32)
        for piece_lens in tail_len_pieces
    ]

    return head_idx_tensors, tail_idx_tensors, head_seqlens, tail_seqlens
|
||||
|
||||
def _split_multi_batch_kv_idx(
    self,
    kv_nomask_idx_multi_batch,
    split_size,
):
    """Cut every request's index list into rounds of ``split_size``.

    Round ``r`` covers positions ``[r * split_size, (r + 1) * split_size)``
    of each request. The number of rounds is set by the longest request;
    shorter requests contribute empty slices to later rounds.

    Args:
        kv_nomask_idx_multi_batch: One index sequence per request.
        split_size: Maximum number of indices taken per request per round.

    Returns:
        ``(merged_idx, merged_lens)``. ``merged_idx[r]`` is the
        concatenation, in request order, of each request's slice for round
        ``r``. ``merged_lens[r][b]`` is the length of request ``b``'s slice
        in round ``r``. Both are empty lists when there are no rounds.
    """
    longest = max(
        (len(request) for request in kv_nomask_idx_multi_batch), default=0)
    num_rounds = (longest + split_size - 1) // split_size

    merged_idx = []
    merged_lens = []
    for round_id in range(num_rounds):
        begin = round_id * split_size
        slices = [
            request[begin:begin + split_size]
            for request in kv_nomask_idx_multi_batch
        ]
        merged_idx.append([idx for piece in slices for idx in piece])
        merged_lens.append([len(piece) for piece in slices])

    return merged_idx, merged_lens
|
||||
|
||||
Reference in New Issue
Block a user