Revert "MLA prefill preformance optimization (#5275)" (#5410)

We'll release 0.13.0 soon. The main branch is frozen. Let's revert the
newest change and redo it once 0.13.0 is released.
- vLLM version: release/v0.13.0
- vLLM main: 81786c8774
wangxiyuan
2025-12-27 09:48:56 +08:00
committed by GitHub
parent 711f1861e4
commit d1f0df7b4b
4 changed files with 50 additions and 361 deletions


@@ -3217,8 +3217,6 @@ class NPUModelRunner(GPUModelRunner):
q_head_idx, q_tail_idx = [], []
kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], []
split_with_q_head_nomask_idx_reqs = []
split_kv_with_q_tail_nomask_idx_reqs = []
chunk_seqlens = []
kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], []
q_req_offset = 0
@@ -3244,10 +3242,7 @@ class NPUModelRunner(GPUModelRunner):
(q_head_chunk_id + 1))))
kv_with_q_head_nomask_seqlens.append(chunk_len *
q_head_chunk_id)
split_with_q_head_nomask_idx_reqs.append(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_head_chunk_id)))
q_tail_idx.extend(
list(
range(q_req_offset + chunk_len,
@@ -3264,17 +3259,21 @@ class NPUModelRunner(GPUModelRunner):
(q_tail_chunk_id + 1))))
kv_with_q_tail_nomask_seqlens.append(chunk_len *
q_tail_chunk_id)
split_kv_with_q_tail_nomask_idx_reqs.append(
list(
range(kv_req_offset, kv_req_offset +
chunk_len * q_tail_chunk_id)))
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_size
q_head_idx_tensor = self._list_to_tensor(
q_head_idx, self.device)
q_tail_idx_tensor = self._list_to_tensor(
q_tail_idx, self.device)
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
    tensor_npu = torch.zeros(len(lst),
                             dtype=dtype,
                             device=device)
    tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
                     non_blocking=True)
    return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device)
self.q_head_idx_tensor = q_head_idx_tensor
self.q_tail_idx_tensor = q_tail_idx_tensor
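As a side note, the restored helper allocates the destination tensor on the device first and then issues a host-to-device copy with non_blocking=True. A minimal standalone sketch of that pattern follows (not part of the diff; the pin_memory step is an assumption about intent, the original code just calls copy_):

import torch

def list_to_tensor_sketch(lst, device, dtype=torch.int32):
    # Build the source tensor on the host; pinning it is what lets a
    # non_blocking copy actually overlap with other work (assumption,
    # not something the original helper does).
    host = torch.tensor(lst, dtype=dtype)
    if torch.cuda.is_available():
        host = host.pin_memory()
    # Allocate the destination on the target device, then copy asynchronously.
    dst = torch.zeros(len(lst), dtype=dtype, device=device)
    dst.copy_(host, non_blocking=True)
    return dst

device = "cuda" if torch.cuda.is_available() else "cpu"
print(list_to_tensor_sketch([3, 1, 4, 1, 5], device))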
@@ -3292,7 +3291,7 @@ class NPUModelRunner(GPUModelRunner):
'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
}
for key, value in self.kv_idx_names.items():
tensor_npu = self._list_to_tensor(value, self.device)
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
@@ -3303,11 +3302,6 @@ class NPUModelRunner(GPUModelRunner):
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
if self.vllm_config.model_config.use_mla:
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list(
split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens)
pcp_prefill_mask = self.attn_mask
self.extra_long_seq_kwargs = {
@@ -3338,99 +3332,9 @@ class NPUModelRunner(GPUModelRunner):
'tail_attn_nomask_seqlens']
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
'pcp_prefill_mask']
if self.vllm_config.model_config.use_mla:
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
def _list_to_tensor(self, lst, device, dtype=torch.int32):
    tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
    tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
    return tensor_npu
def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs,
split_kv_with_q_tail_nomask_idx_reqs,
head_attn_nomask_seqlens, chunk_seqlens):
split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list = [], []
head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], []
if split_with_q_head_nomask_idx_reqs:
# In long-sequence scenarios, the computational cost and latency
# of the _npu_ring_mla operator are not proportional, so we split
# long sequences into shorter ones to improve performance.
split_size = 16 * 1024
if self.pcp_rank == 0:
split_q_head_nomask_idx_list = [
self.kv_idx_names['kv_with_q_head_nomask_idx_tensor']
]
else:
split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx(
split_with_q_head_nomask_idx_reqs, split_size)
split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx(
split_kv_with_q_tail_nomask_idx_reqs, split_size)
for q_head_nomask_idx in split_q_head_nomask_idx_list:
split_q_head_nomask_idx_tensor_list.append(
self._list_to_tensor(q_head_nomask_idx, self.device))
for q_tail_nomask_idx in split_q_tail_nomask_idx_list:
split_q_tail_nomask_idx_tensor_list.append(
self._list_to_tensor(q_tail_nomask_idx, self.device))
if self.pcp_rank == 0:
head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens]
else:
for q_head_nomask_lens in split_q_head_nomask_lens_list:
head_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_head_nomask_lens],
dtype=torch.int32))
for q_tail_nomask_lens in split_q_tail_nomask_lens_list:
tail_attn_nomask_seqlens_list.append(
torch.tensor([chunk_seqlens, q_tail_nomask_lens],
dtype=torch.int32))
return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list
def _split_multi_batch_kv_idx(
    self,
    kv_nomask_idx_multi_batch,
    split_size,
):
    batch_lengths = [len(batch) for batch in kv_nomask_idx_multi_batch]
    max_batch_length = max(batch_lengths) if batch_lengths else 0
    time = (max_batch_length + split_size - 1) // split_size
    split_kv_idx_3d = []
    split_kv_len_2d = []
    merged_split_kv_idx_3d = []
    for single_batch in kv_nomask_idx_multi_batch:
        current_batch_split = []
        current_batch_len = []
        for t in range(time):
            start = t * split_size
            current_segment = single_batch[start:start + split_size]
            current_batch_split.append(current_segment)
            current_batch_len.append(len(current_segment))
        split_kv_idx_3d.append(current_batch_split)
        split_kv_len_2d.append(current_batch_len)
    for time_idx in range(time):
        current_time_merged = []
        for batch in split_kv_idx_3d:
            current_time_merged.extend(batch[time_idx])
        merged_split_kv_idx_3d.append(current_time_merged)
    def reshape_kv_len_to_time_first(split_kv_len_2d):
        if not split_kv_len_2d or not split_kv_len_2d[0]:
            return []
        return [[batch_len[time_idx] for batch_len in split_kv_len_2d]
                for time_idx in range(len(split_kv_len_2d[0]))]
    merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d)
    return merged_split_kv_idx_3d, merged_split_kv_len_2d
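To make the removed splitting logic easier to follow, here is a small standalone illustration (not part of the diff) of what _split_multi_batch_kv_idx computes, written as a compact re-expression of the same algorithm: each request's no-mask KV index list is cut into split_size chunks, and chunks at the same position across requests are concatenated, which, per the comment above, keeps each _npu_ring_mla call on a shorter sequence.

def split_multi_batch_kv_idx(kv_nomask_idx_multi_batch, split_size):
    # The number of chunk positions is driven by the longest request.
    batch_lengths = [len(batch) for batch in kv_nomask_idx_multi_batch]
    steps = (max(batch_lengths, default=0) + split_size - 1) // split_size
    merged_idx, merged_len = [], []
    for t in range(steps):
        # Take the t-th chunk of every request (may be empty for short ones).
        segs = [batch[t * split_size:(t + 1) * split_size]
                for batch in kv_nomask_idx_multi_batch]
        merged_idx.append([i for seg in segs for i in seg])
        merged_len.append([len(seg) for seg in segs])
    return merged_idx, merged_len

# Two requests with 6 and 3 no-mask KV indices, split into chunks of 4:
idx, lens = split_multi_batch_kv_idx([[0, 1, 2, 3, 4, 5], [10, 11, 12]], 4)
print(idx)   # [[0, 1, 2, 3, 10, 11, 12], [4, 5]]
print(lens)  # [[4, 3], [2, 0]]

In the reverted code the merged index lists are then moved to the device via _list_to_tensor, the per-chunk lengths feed the seqlens tensors, and the actual split_size is 16 * 1024.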
def _generate_pcp_mtp_input(
self,
num_reqs: int,