diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index b486952c..a7597af8 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -865,7 +865,7 @@ class TestAscendMLAImpl(TestBase): q_head_idx, q_tail_idx, kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx, kv_with_q_tail_nomask_idx, \ kv_with_q_tail_mask_idx, chunk_seqlens, kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = get_pcp_split_info( rank, pcp_size, nums_tokens_per_rank) - kv_with_q_head_nomask_idx = [kv_with_q_head_nomask_idx] + output_head, lse_head = self.impl._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx), @@ -876,16 +876,15 @@ class TestAscendMLAImpl(TestBase): kv_nomask_idx=kv_with_q_head_nomask_idx, attn_mask_seqlens=torch.tensor( [chunk_seqlens, chunk_seqlens], dtype=torch.int32), - attn_nomask_seqlens=[kv_with_q_head_nomask_seqlens], + attn_nomask_seqlens=kv_with_q_head_nomask_seqlens, mask=mask) self.assertEqual(output_head.shape, (q_head_idx.shape[0], num_heads, v_head_dim)) self.assertEqual(lse_head.shape, (num_heads, q_head_idx.shape[0])) self.assertEqual(mock_npu_ring_mla.call_count, - 1 + (len(kv_with_q_head_nomask_idx[0]) != 0)) + 1 + (kv_with_q_head_nomask_idx.shape[0] != 0)) mock_npu_ring_mla.reset_mock() - kv_with_q_tail_nomask_idx = [kv_with_q_tail_nomask_idx] output_tail, lse_tail = self.impl._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_tail_idx), q_pe=torch.index_select(q_pe, 0, q_tail_idx), @@ -896,7 +895,7 @@ class TestAscendMLAImpl(TestBase): kv_nomask_idx=kv_with_q_tail_nomask_idx, attn_mask_seqlens=torch.tensor( [chunk_seqlens, chunk_seqlens], dtype=torch.int32), - attn_nomask_seqlens=[kv_with_q_tail_nomask_seqlens], + attn_nomask_seqlens=kv_with_q_tail_nomask_seqlens, mask=mask) self.assertEqual(output_tail.shape, @@ -904,7 +903,7 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(lse_tail.shape, (num_heads, q_tail_idx.shape[0])) self.assertEqual(mock_npu_ring_mla.call_count, - 1 + (len(kv_with_q_tail_nomask_idx[0]) != 0)) + 1 + (kv_with_q_tail_nomask_idx.shape[0] != 0)) mock_npu_ring_mla.reset_mock() @patch("torch.distributed.all_to_all_single") diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index 6681c79a..8ff26a6f 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -73,15 +73,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, mock_runner.query_lens = torch.tensor(query_lens) - mock_runner._get_cp_local_seq_lens.side_effect = NPUModelRunner._get_cp_local_seq_lens.__get__( - mock_runner, NPUModelRunner) - mock_runner._list_to_tensor.side_effect = NPUModelRunner._list_to_tensor.__get__( - mock_runner, NPUModelRunner) - mock_runner._split_nomask_idx_tensor_list.side_effect = NPUModelRunner._split_nomask_idx_tensor_list.__get__( - mock_runner, NPUModelRunner) - mock_runner._split_multi_batch_kv_idx.side_effect = NPUModelRunner._split_multi_batch_kv_idx.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__( mock_runner, NPUModelRunner) @@ -97,7 +88,9 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, mock_runner.q_tail_idx_tensor = None mock_runner.q_full_idx = None - result = NPUModelRunner._generate_pcp_metadata(mock_runner, total_tokens) + method = NPUModelRunner._generate_pcp_metadata.__get__( + mock_runner, NPUModelRunner) + result = method(total_tokens) if not expect_not_none: assert result is None, f"Expected to return None, but got {type(result)}" @@ -478,201 +471,3 @@ def test_generate_pcp_mtp_input( target_input_ids_pcp_full) assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1], target_query_start_loc_pcp_full) - - -@pytest.mark.parametrize( - "pcp_rank, split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs," - "head_attn_nomask_seqlens, chunk_seqlens," - "target_split_q_head, target_split_q_tail, target_head_seqlens, target_tail_seqlens", - [ - # case1: pcp_rank=0 - (0, [[10, 20, 30]], [[40, 50, 60]], - torch.tensor([[64], [0]], dtype=torch.int32), [64], [ - torch.tensor([1, 2, 3], dtype=torch.int32) - ], [torch.tensor([40, 50, 60], dtype=torch.int32)], [ - torch.tensor([[64], [0]], dtype=torch.int32) - ], [torch.tensor([[64], [3]], dtype=torch.int32)]), - # case2: pcp_rank=1 - (1, [[1, 2], [3, 4, 5]], [[6, 7], [8, 9, 10]], - torch.tensor([[128, 128], [128, 128]], dtype=torch.int32), [128, 128], - [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)], [ - torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32) - ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32) - ], [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)]), - # case3: pcp_rank=2 - (2, [[11, 12, 13, 14], [15, 16]], [[17, 18, 19], [20, 21, 22, 23]], - torch.tensor([[256, 256], [512, 512]], dtype=torch.int32), [256, 256], - [torch.tensor([11, 12, 13, 14, 15, 16], dtype=torch.int32)], [ - torch.tensor([17, 18, 19, 20, 21, 22, 23], dtype=torch.int32) - ], [torch.tensor([[256, 256], [4, 2]], dtype=torch.int32) - ], [torch.tensor([[256, 256], [3, 4]], dtype=torch.int32)]), - # case4: empty input - ( - 0, - [], - [], - torch.tensor([], dtype=torch.int32).reshape(2, 0), - [], - [], - [], - [], - [], - ), - # case5: single element input - ( - 0, - [[10]], - [[40]], - torch.tensor([[64], [0]], dtype=torch.int32), - [64], - [torch.tensor([1, 2, 3], dtype=torch.int32)], - [torch.tensor([40], dtype=torch.int32)], - [torch.tensor([[64], [0]], dtype=torch.int32)], - [torch.tensor([[64], [1]], dtype=torch.int32)], - ), - # case6: pcp_rank=3 - ( - 3, - [[1, 2], [3, 4, 5]], - [[6, 7], [8, 9, 10]], - torch.tensor([[128, 128], [128, 128]], dtype=torch.int32), - [128, 128], - [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int32)], - [torch.tensor([6, 7, 8, 9, 10], dtype=torch.int32)], - [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)], - [torch.tensor([[128, 128], [2, 3]], dtype=torch.int32)], - ), - ]) -def test_split_nomask_idx_tensor_list( - pcp_rank, split_with_q_head_nomask_idx_reqs, - split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens, - chunk_seqlens, target_split_q_head, target_split_q_tail, - target_head_seqlens, target_tail_seqlens): - # Mock input data - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.device = "cpu" - mock_runner.pcp_rank = 0 - mock_runner.kv_idx_names = { - "kv_with_q_head_nomask_idx_tensor": - torch.tensor([1, 2, 3], dtype=torch.int32) - } - - mock_runner.pcp_rank = pcp_rank - - # Mock output - mock_runner._split_multi_batch_kv_idx.side_effect = NPUModelRunner._split_multi_batch_kv_idx.__get__( - mock_runner, NPUModelRunner) - mock_runner._list_to_tensor.side_effect = NPUModelRunner._list_to_tensor.__get__( - mock_runner, NPUModelRunner) - - # Call the method under test - result = NPUModelRunner._split_nomask_idx_tensor_list( - mock_runner, - split_with_q_head_nomask_idx_reqs=split_with_q_head_nomask_idx_reqs, - split_kv_with_q_tail_nomask_idx_reqs= - split_kv_with_q_tail_nomask_idx_reqs, - head_attn_nomask_seqlens=head_attn_nomask_seqlens, - chunk_seqlens=chunk_seqlens) - split_q_head, split_q_tail, head_seqlens, tail_seqlens = result - - # Assert the method call - assert len(split_q_head) == len(target_split_q_head) - for res, target in zip(split_q_head, target_split_q_head): - assert torch.equal(res, target) - - assert len(split_q_tail) == len(target_split_q_tail) - for res, target in zip(split_q_tail, target_split_q_tail): - assert torch.equal(res, target) - - assert len(head_seqlens) == len(target_head_seqlens) - for res, target in zip(head_seqlens, target_head_seqlens): - if isinstance(target, torch.Tensor): - assert torch.equal(res, target) - else: - assert res == target - - assert len(tail_seqlens) == len(target_tail_seqlens) - for res, target in zip(tail_seqlens, target_tail_seqlens): - if isinstance(target, torch.Tensor): - assert torch.equal(res, target) - else: - assert res == target - - -@pytest.mark.parametrize( - "kv_nomask_idx_multi_batch, split_size, expected_merged_idx, expected_merged_len", - [ - # case1: multiple batches + split size greater than batch length - ( - [[0, 1, 2, 3, 4], [5, 6, 7]], - 2, - # expected merged_split_kv_idx_3d - [[0, 1, 5, 6], [2, 3, 7], [4]], - # expected merged_split_kv_len_2d - [[2, 2], [2, 1], [1, 0]], - ), - # case2: single batch + split size greater than batch length - ( - [[0, 1, 2]], - 5, - [[0, 1, 2]], - [[3]], - ), - # case3: split size equals maximum batch length - ( - [[0, 1, 2, 3], [5, 6]], - 4, - [[0, 1, 2, 3, 5, 6]], - [[4, 2]], - ), - # case4: Split size is 1 (minimum granularity split) - ( - [[0, 1], [2]], - 1, - [[0, 2], [1]], - [[1, 1], [1, 0]], - ), - # case6: the batch contains an empty list - ( - [[], [0, 1], [2]], - 1, - [[0, 2], [1]], - [[0, 1, 1], [0, 1, 0]], - ), - # case: empty input - ( - [], - 2, - [], - [], - ), - ]) -def test_split_multi_batch_kv_idx( - kv_nomask_idx_multi_batch, - split_size, - expected_merged_idx, - expected_merged_len, -): - # Mock input data - model_runner = MagicMock(spec=NPUModelRunner) - - # Call the method under test - result = NPUModelRunner._split_multi_batch_kv_idx( - self=model_runner, - kv_nomask_idx_multi_batch=kv_nomask_idx_multi_batch, - split_size=split_size) - - merged_split_kv_idx_3d, merged_split_kv_len_2d = result - - # Assert the method call - assert len(merged_split_kv_idx_3d) == len(expected_merged_idx) - - for t, (actual_seg, expected_seg) in enumerate( - zip(merged_split_kv_idx_3d, expected_merged_idx)): - assert actual_seg == expected_seg - - assert len(merged_split_kv_len_2d) == len(expected_merged_len) - - for t, (actual_len, expected_len) in enumerate( - zip(merged_split_kv_len_2d, expected_merged_len)): - assert actual_len == expected_len diff --git a/vllm_ascend/attention/mla_cp.py b/vllm_ascend/attention/mla_cp.py index b8dea74b..0a3aed14 100644 --- a/vllm_ascend/attention/mla_cp.py +++ b/vllm_ascend/attention/mla_cp.py @@ -778,18 +778,11 @@ class AscendMlaCPImpl(AscendMLAImpl): return output def _attention_with_mask_and_nomask( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - k_nope: torch.Tensor, - k_pe: torch.Tensor, - value: torch.Tensor, - kv_mask_idx: torch.Tensor, - kv_nomask_idx: list[torch.Tensor], - attn_mask_seqlens: torch.Tensor, - attn_nomask_seqlens: list[torch.Tensor], - mask: torch.Tensor, - ): + self, q_nope: torch.Tensor, q_pe: torch.Tensor, + k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor, + kv_mask_idx: torch.Tensor, kv_nomask_idx: torch.Tensor, + attn_mask_seqlens: torch.Tensor, attn_nomask_seqlens: torch.Tensor, + mask: torch.Tensor): attn_output = torch.empty(q_nope.shape[0], self.num_heads, self.v_head_dim, @@ -823,32 +816,30 @@ class AscendMlaCPImpl(AscendMLAImpl): softmax_lse=attn_lse) # nomask - if not kv_nomask_idx or len(kv_nomask_idx[0]) == 0: + if kv_nomask_idx.shape[0] == 0: return attn_output, attn_lse - for kv_nomask_idx_split, attn_nomask_seqlens_split in zip( - kv_nomask_idx, attn_nomask_seqlens): - k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx_split) - value_nomask = torch.index_select(value, 0, kv_nomask_idx_split) - k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx_split) - torch_npu.atb.npu_ring_mla( - q_nope=q_nope, - q_rope=q_pe, - k_nope=k_nope_nomask, - k_rope=k_pe_nomask, - value=value_nomask, - mask=mask, - seqlen=attn_nomask_seqlens_split, - head_num=self.num_heads, - kv_head_num=self.num_heads, - pre_out=attn_output, - prev_lse=attn_lse, - qk_scale=self.scale, - kernel_type="kernel_type_high_precision", - mask_type="no_mask", - input_layout="type_bsnd", - calc_type="calc_type_default", - output=attn_output, - softmax_lse=attn_lse) + + k_nope_nomask = torch.index_select(k_nope, 0, kv_nomask_idx) + value_nomask = torch.index_select(value, 0, kv_nomask_idx) + k_pe_nomask = torch.index_select(k_pe, 0, kv_nomask_idx) + torch_npu.atb.npu_ring_mla(q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope_nomask, + k_rope=k_pe_nomask, + value=value_nomask, + mask=mask, + seqlen=attn_nomask_seqlens, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=attn_output, + prev_lse=attn_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=attn_output, + softmax_lse=attn_lse) return attn_output, attn_lse def _forward_decode_pcp_dcp( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8832972b..b01c8f4d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3217,8 +3217,6 @@ class NPUModelRunner(GPUModelRunner): q_head_idx, q_tail_idx = [], [] kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - split_with_q_head_nomask_idx_reqs = [] - split_kv_with_q_tail_nomask_idx_reqs = [] chunk_seqlens = [] kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] q_req_offset = 0 @@ -3244,10 +3242,7 @@ class NPUModelRunner(GPUModelRunner): (q_head_chunk_id + 1)))) kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id) - split_with_q_head_nomask_idx_reqs.append( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_head_chunk_id))) + q_tail_idx.extend( list( range(q_req_offset + chunk_len, @@ -3264,17 +3259,21 @@ class NPUModelRunner(GPUModelRunner): (q_tail_chunk_id + 1)))) kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id) - split_kv_with_q_tail_nomask_idx_reqs.append( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_tail_chunk_id))) + q_req_offset += seq_len kv_req_offset += seq_len * self.pcp_size - q_head_idx_tensor = self._list_to_tensor( - q_head_idx, self.device) - q_tail_idx_tensor = self._list_to_tensor( - q_tail_idx, self.device) + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), + dtype=dtype, + device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) self.q_head_idx_tensor = q_head_idx_tensor self.q_tail_idx_tensor = q_tail_idx_tensor @@ -3292,7 +3291,7 @@ class NPUModelRunner(GPUModelRunner): 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx } for key, value in self.kv_idx_names.items(): - tensor_npu = self._list_to_tensor(value, self.device) + tensor_npu = _list_to_tensor(value, self.device) self.kv_idx_names[key] = tensor_npu attn_mask_seqlens = torch.tensor( @@ -3303,11 +3302,6 @@ class NPUModelRunner(GPUModelRunner): tail_attn_nomask_seqlens = torch.tensor( [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32) - if self.vllm_config.model_config.use_mla: - split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = self._split_nomask_idx_tensor_list( - split_with_q_head_nomask_idx_reqs, - split_kv_with_q_tail_nomask_idx_reqs, - head_attn_nomask_seqlens, chunk_seqlens) pcp_prefill_mask = self.attn_mask self.extra_long_seq_kwargs = { @@ -3338,99 +3332,9 @@ class NPUModelRunner(GPUModelRunner): 'tail_attn_nomask_seqlens'] long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ 'pcp_prefill_mask'] - if self.vllm_config.model_config.use_mla: - long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list - long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list - long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list - long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list self.long_seq_metadata = long_seq_metadata return long_seq_metadata - def _list_to_tensor(self, lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True) - return tensor_npu - - def _split_nomask_idx_tensor_list(self, split_with_q_head_nomask_idx_reqs, - split_kv_with_q_tail_nomask_idx_reqs, - head_attn_nomask_seqlens, chunk_seqlens): - split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list= [], [] - head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list = [], [] - if split_with_q_head_nomask_idx_reqs: - #In long-sequence scenarios, the computational cost and latency - #of the _npu_ring_mla operator are not proportional, so we split - #long sequences into shorter ones to improve performance. - split_size = 16 * 1024 - if self.pcp_rank == 0: - split_q_head_nomask_idx_list = [ - self.kv_idx_names['kv_with_q_head_nomask_idx_tensor'] - ] - else: - split_q_head_nomask_idx_list, split_q_head_nomask_lens_list = self._split_multi_batch_kv_idx( - split_with_q_head_nomask_idx_reqs, split_size) - split_q_tail_nomask_idx_list, split_q_tail_nomask_lens_list = self._split_multi_batch_kv_idx( - split_kv_with_q_tail_nomask_idx_reqs, split_size) - - for q_head_nomask_idx in split_q_head_nomask_idx_list: - split_q_head_nomask_idx_tensor_list.append( - self._list_to_tensor(q_head_nomask_idx, self.device)) - - for q_tail_nomask_idx in split_q_tail_nomask_idx_list: - split_q_tail_nomask_idx_tensor_list.append( - self._list_to_tensor(q_tail_nomask_idx, self.device)) - - if self.pcp_rank == 0: - head_attn_nomask_seqlens_list = [head_attn_nomask_seqlens] - else: - for q_head_nomask_lens in split_q_head_nomask_lens_list: - head_attn_nomask_seqlens_list.append( - torch.tensor([chunk_seqlens, q_head_nomask_lens], - dtype=torch.int32)) - for q_tail_nomask_lens in split_q_tail_nomask_lens_list: - tail_attn_nomask_seqlens_list.append( - torch.tensor([chunk_seqlens, q_tail_nomask_lens], - dtype=torch.int32)) - return split_q_head_nomask_idx_tensor_list, split_q_tail_nomask_idx_tensor_list, head_attn_nomask_seqlens_list, tail_attn_nomask_seqlens_list - - def _split_multi_batch_kv_idx( - self, - kv_nomask_idx_multi_batch, - split_size, - ): - batch_lengths = [len(batch) for batch in kv_nomask_idx_multi_batch] - max_batch_length = max(batch_lengths) if batch_lengths else 0 - time = (max_batch_length + split_size - 1) // split_size - split_kv_idx_3d = [] - split_kv_len_2d = [] - merged_split_kv_idx_3d = [] - - for single_batch in kv_nomask_idx_multi_batch: - current_batch_split = [] - current_batch_len = [] - for t in range(time): - start = t * split_size - current_segment = single_batch[start:start + split_size] - current_batch_split.append(current_segment) - current_batch_len.append(len(current_segment)) - - split_kv_idx_3d.append(current_batch_split) - split_kv_len_2d.append(current_batch_len) - - for time_idx in range(time): - current_time_merged = [] - for batch in split_kv_idx_3d: - current_time_merged.extend(batch[time_idx]) - merged_split_kv_idx_3d.append(current_time_merged) - - def reshape_kv_len_to_time_first(split_kv_len_2d): - if not split_kv_len_2d or not split_kv_len_2d[0]: - return [] - return [[batch_len[time_idx] for batch_len in split_kv_len_2d] - for time_idx in range(len(split_kv_len_2d[0]))] - - merged_split_kv_len_2d = reshape_kv_len_to_time_first(split_kv_len_2d) - return merged_split_kv_idx_3d, merged_split_kv_len_2d - def _generate_pcp_mtp_input( self, num_reqs: int,