diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py index 61d99c97..7df19de4 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_accuracy.py @@ -210,3 +210,72 @@ def test_accuracy_pcp_only(max_tokens: int, ) -> None: name_0="vllm_eager_outputs", name_1="vllm_pcp_only_outputs", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [10]) +def test_models_long_sequence_cp_kv_interleave_size_output_between_tp_and_cp( + model: str, + max_tokens: int, +) -> None: + prompts = [ + "The president of the United States is", "The capital of France is" + ] + + common_kwargs = { + "max_model_len": 1024, + } + + if model == "vllm-ascend/DeepSeek-V2-Lite-W8A8": + cp_kwargs = { + "tensor_parallel_size": 2, + "decode_context_parallel_size": 2, + "prefill_context_parallel_size": 2, + "enable_expert_parallel": True, + "cp_kv_cache_interleave_size": 128, + "enforce_eager": True, + "quantization": "ascend", + } + tp_kwargs = { + "tensor_parallel_size": 4, + "enable_expert_parallel": True, + "enforce_eager": True, + "quantization": "ascend", + } + + else: + cp_kwargs = { + "tensor_parallel_size": 1, + "decode_context_parallel_size": 1, + "prefill_context_parallel_size": 2, + "cp_kv_cache_interleave_size": 128, + "compilation_config": { + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [4, 8, 24, 48, 60] + }, + } + tp_kwargs = { + "tensor_parallel_size": 2, + "enforce_eager": True, + } + + cp_full_kwargs = {} + cp_full_kwargs.update(common_kwargs) # type: ignore + cp_full_kwargs.update(cp_kwargs) # type: ignore + + tp_full_kwargs = {} + tp_full_kwargs.update(common_kwargs) # type: ignore + tp_full_kwargs.update(tp_kwargs) # type: ignore + with VllmRunner(model, **cp_full_kwargs) as runner: # type: ignore + vllm_context_parallel_outputs = runner.generate_greedy( + prompts, max_tokens) + + with VllmRunner(model, **tp_full_kwargs) as runner: # type: ignore + vllm_eager_outputs = runner.generate_greedy(prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs, + outputs_1_lst=vllm_context_parallel_outputs, + name_0="vllm_eager_outputs", + name_1="vllm_context_parallel_outputs", + ) \ No newline at end of file diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index f02c1b88..801e42c9 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -439,11 +439,7 @@ class TestAscendMLAImpl(TestBase): decode_metadata = MagicMock() decode_metadata.actual_seq_lengths_q = MagicMock() decode_metadata.seq_lens_list = MagicMock() - decode_metadata.batch_seq_mask = torch.tensor([True, False], - dtype=torch.bool) - - result = _process_attn_out_lse(attn_output, softmax_lse, - decode_metadata.batch_seq_mask) + result = _process_attn_out_lse(attn_output, softmax_lse) self.assertEqual(result.shape[0], B * self.impl.pcp_size) self.assertEqual(result.shape[1], N) @@ -478,8 +474,6 @@ class TestAscendMLAImpl(TestBase): attn_metadata.decode = MagicMock() attn_metadata.decode.actual_seq_lengths_q = MagicMock() attn_metadata.decode.seq_lens_list = MagicMock() - attn_metadata.decode.batch_seq_mask = torch.tensor([False, False], - dtype=torch.bool) self.impl.enable_kv_nz = True @@ -886,12 +880,9 @@ class TestAscendMLAImpl(TestBase): # Inputs attn_output = torch.randn(B, H, D) softmax_lse = torch.randn(B, H, 1) - batch_seq_mask = torch.tensor([False, True, False, False]) # [B] decode_meta = MagicMock() - decode_meta.batch_seq_mask = batch_seq_mask - result = _process_attn_out_lse(attn_output, softmax_lse, - batch_seq_mask) + result = _process_attn_out_lse(attn_output, softmax_lse) # [PCP * S, DCP * H, D + 1] self.assertIsInstance(result, torch.Tensor) assert result.shape == (B * self.impl.pcp_size, H, D + 1) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index d82218e0..6059ecab 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -137,7 +137,6 @@ class TestAscendMLADecodeMetadata(TestBase): seq_lens_list = [2, 3] attn_mask = None cp_seq_len = torch.tensor([2, 3]) - batch_seq_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]]) metadata = AscendMLADecodeMetadata(input_positions=input_positions, block_table=block_table, @@ -145,8 +144,7 @@ class TestAscendMLADecodeMetadata(TestBase): max_seq_lens=max_seq_lens, seq_lens_list=seq_lens_list, attn_mask=attn_mask, - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + cp_seq_len=cp_seq_len) self.assertIs(metadata.input_positions, input_positions) self.assertIs(metadata.block_table, block_table) @@ -155,7 +153,6 @@ class TestAscendMLADecodeMetadata(TestBase): self.assertEqual(metadata.seq_lens_list, seq_lens_list) self.assertIsNone(attn_mask) self.assertIs(metadata.cp_seq_len, cp_seq_len) - self.assertIs(metadata.batch_seq_mask, batch_seq_mask) class TestAscendMLAMetadata(TestBase): diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index cae53590..ae406aa9 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -73,9 +73,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) - self.batch_seq_mask_buf = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, dtype=torch.uint8, device=device - ) self.pcp_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() @@ -216,14 +213,9 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): if num_decodes > 0: num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp) num_computed_tokens_array = num_computed_tokens_array[:num_decodes] - batch_seq_mask = num_computed_tokens_array[:, self.pcp_rank, self.dcp_rank] == 0 # TODO: numpy array mode of the shared memory is used to improve performance - self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_( - torch.from_numpy(batch_seq_mask), non_blocking=True - ) decode_metadata = AscendMetadataForDecode( num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, - batch_seq_mask=self.batch_seq_mask_buf[: batch_seq_mask.shape[0]], block_tables=block_table[:num_decodes], ) @@ -525,7 +517,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): graph_params.handles[num_tokens].append(handle) else: attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(query, k_nope, value, **common_kwargs) - attn_out_lse = _process_attn_out_lse(attn_out, attn_lse, attn_metadata.decode_meta.batch_seq_mask) + attn_out_lse = _process_attn_out_lse(attn_out, attn_lse) attn_out = _npu_attention_update(self.head_size, attn_out_lse) return attn_out @@ -633,9 +625,6 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): actual_seq_lengths_kv=prefill_metadata.chunked_context.actual_seq_lengths_kv, actual_seq_lengths=attn_metadata.prefill.chunked_context.actual_chunk_seq_lengths, ) - batch_chunk_seq_mask = attn_metadata.prefill.chunked_context.batch_chunk_seq_mask - lse_mask = batch_chunk_seq_mask[:, None, None].expand_as(prefix_chunk_lse) - prefix_chunk_lse = torch.where(lse_mask, -torch.inf, prefix_chunk_lse) return prefix_chunk_output, prefix_chunk_lse diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 65103059..03038b25 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -84,20 +84,13 @@ class AscendMetadataForDecode: """Decode-specific metadata for Ascend attention with Context Parallelism.""" num_computed_tokens_of_pcp_dcp: list[list[list[int]]] | None = None - batch_seq_mask: torch.Tensor = None block_tables: torch.Tensor = None -def _process_attn_out_lse( - attn_output: torch.Tensor, softmax_lse: torch.Tensor, batch_seq_mask: torch.Tensor -) -> torch.Tensor: +def _process_attn_out_lse(attn_output: torch.Tensor, softmax_lse: torch.Tensor) -> torch.Tensor: pcp_size = get_pcp_group().world_size dcp_size = get_decode_context_model_parallel_world_size() dcp_group = get_dcp_group().device_group if dcp_size > 1 else None - out_mask = batch_seq_mask[:, None, None].expand_as(attn_output) - attn_output = torch.where(out_mask, 0, attn_output) - lse_mask = batch_seq_mask[:, None, None].expand_as(softmax_lse) - softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse) softmax_lse = softmax_lse.to(torch.float32) attn_output = attn_output.to(torch.float32) # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1] diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 81298f94..8b817503 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -68,10 +68,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.cp_local_block_size = vllm_config.parallel_config.cp_kv_cache_interleave_size self.cp_virtual_block_size = self.cp_local_block_size * self.dcp_size * self.pcp_size - scheduler_config = vllm_config.scheduler_config - decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0) - max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) - self.batch_seq_mask_buf = torch.empty(max_num_seqs * self.decode_threshold, dtype=torch.uint8, device=device) self.block_size = (self.block_size * self.cp_virtual_block_size) // np.gcd( self.block_size, self.cp_virtual_block_size ) @@ -238,12 +234,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): cp_seq_len = num_computed_tokens_of_cp_dcp_array[:, self.pcp_rank, self.dcp_rank] cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32) - batch_seq_mask = cp_seq_len == 0 - self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True) - batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]] - cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) decode_metadata.cp_seq_len = cp_seq_len.tolist() - decode_metadata.batch_seq_mask = batch_seq_mask actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1 decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q @@ -651,7 +642,7 @@ class AscendMlaCPImpl(AscendMLAImpl): softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B_lse * Q_S, N_lse, 1) # Update out&lse - attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask) + attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse) attn_output = _npu_attention_update(self.kv_lora_rank, attn_out_lse) return self._v_up_proj(attn_output) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 9383a165..1ec14c94 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -134,7 +134,6 @@ class AscendMLADecodeMetadata: sin: torch.Tensor = None cos: torch.Tensor = None cp_seq_len: torch.Tensor = None - batch_seq_mask: torch.Tensor = None @dataclass @@ -577,7 +576,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.block_table = self.block_table[:self.graph_pad_size, ...] seq_lens_list = self.seq_lens.tolist() - cp_seq_len, batch_seq_mask = None, None + cp_seq_len = None if self.graph_pad_size > num_reqs: if self.speculative_config.disable_padded_drafter_batch: @@ -638,8 +637,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): actual_seq_lengths_q=actual_seq_lengths_q, sin=sin[:self.num_decode_tokens, ...], cos=cos[:self.num_decode_tokens, ...], - cp_seq_len=cp_seq_len, - batch_seq_mask=batch_seq_mask) + cp_seq_len=cp_seq_len) return decode_metadata def build_for_graph_capture( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index dac8b5d1..05f03389 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -477,7 +477,7 @@ class MtpProposer(EagleProposer): self.positions[:batch_size] = clamped_positions self.hidden_states[:hidden_states.shape[0]] = hidden_states if self.pcp_size * self.dcp_size > 1: - # update local seq_len and batch_seq_mask + # update local seq_len num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( ori_seq_len + step + 1, self.pcp_size, @@ -486,14 +486,7 @@ class MtpProposer(EagleProposer): ) cp_seq_len = \ num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] - batch_seq_mask = (cp_seq_len == 0) - builder.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_( - batch_seq_mask, non_blocking=True) - batch_seq_mask = builder.batch_seq_mask_buf[:batch_seq_mask. - shape[0]] - cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len) attn_metadata_i.decode.cp_seq_len = cp_seq_len - attn_metadata_i.decode.batch_seq_mask = batch_seq_mask # update slot_mapping slot_indices += self.pcp_size slot_mapping = mtp_slot_mapping[slot_indices]