From 380f089fbf81fcd3987899c741b6c4196cbce2a5 Mon Sep 17 00:00:00 2001 From: LICO67373 <110013619+LICO1314@users.noreply.github.com> Date: Wed, 7 Jan 2026 17:09:52 +0800 Subject: [PATCH] [Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (#4870) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What this PR does / why we need it? This PR fixes the `AttentionMaskBuilder` singleton initialization issue introduced in PR #4779 and removes the unused `pcp_prefill_mask` field. ### Background After PR #4779 made `AttentionMaskBuilder` a singleton with `@singleton` decorator, the class constructor now requires a `device` parameter. However, two initialization sites were still using the old parameterless constructor, causing failures. ### Changes 1. **Fix singleton initialization** - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendMLAMetadataBuilder.__init__()` - Fixed `AttentionMaskBuilder()` → `AttentionMaskBuilder(self.device)` in `AscendAttentionMetadataBuilder.__init__()` 2. **Remove unused field** - Removed `pcp_prefill_mask` field from `AscendPrefillContextParallelMetadata` (never used in codebase) - Updated related test assertions ### Related - Issue #5463 - PR #4779 (Unify all mask generation methods) - PR #5389 (Make AttentionMaskBuilder singleton) ## Does this PR introduce _any_ user-facing change? No. This is an internal refactoring. ## How was this patch tested? - ✅ Local testing: No linter errors - ✅ Unit tests for attention modules verified - ⏳ CI pipeline Signed-off-by: lico67373 <918688502@qq.com> Co-authored-by: weijinqian0 <1184188277@qq.com> --- tests/ut/attention/test_attention_v1.py | 3 +- tests/ut/attention/test_mla_cp.py | 2 - tests/ut/attention/test_mla_v1.py | 56 ++++++++++++++----- tests/ut/spec_decode/test_mtp_proposer.py | 4 -- tests/ut/worker/test_pcp_manager.py | 17 +----- vllm_ascend/attention/attention_mask.py | 19 ++++++- vllm_ascend/attention/attention_v1.py | 15 ++++- .../context_parallel/attention_cp.py | 7 +-- .../attention/context_parallel/common_cp.py | 1 - .../attention/context_parallel/mla_cp.py | 8 +-- vllm_ascend/attention/mla_v1.py | 20 ++++--- vllm_ascend/attention/sfa_v1.py | 5 +- vllm_ascend/attention/utils.py | 11 ---- vllm_ascend/compilation/acl_graph.py | 6 +- vllm_ascend/spec_decode/eagle_proposer.py | 6 -- vllm_ascend/spec_decode/mtp_proposer.py | 4 +- vllm_ascend/utils.py | 11 ++++ vllm_ascend/worker/model_runner_v1.py | 46 +-------------- vllm_ascend/worker/pcp_utils.py | 10 +--- vllm_ascend/worker/v2/attn_utils.py | 5 -- vllm_ascend/worker/v2/model_runner.py | 10 +--- 21 files changed, 118 insertions(+), 148 deletions(-) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 4b82320b..c5f0fc1b 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -53,6 +53,7 @@ class TestAscendAttentionMetadataBuilder(TestBase): self.mock_vllm_config = MagicMock() self.mock_vllm_config.speculative_config = None self.mock_vllm_config.model_config.max_model_len = 640 + self.mock_vllm_config.model_config.hf_text_config.sliding_window = None self.mock_vllm_config.cache_config.block_size = 64 self.mock_vllm_config.compilation_config.cudagraph_mode = None self.mock_vllm_config.scheduler_config.max_num_seqs = 10 @@ -89,8 +90,6 @@ class TestAscendAttentionMetadataBuilder(TestBase): slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1, 2]), 
positions=torch.tensor([10, 10]), - attn_mask=torch.ones((15, 15)), - spec_attn_mask=None, attn_state=AscendAttentionState.ChunkedPrefill, num_computed_tokens_cpu=None, seq_lens=None, diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 74d9ecbc..a8857d14 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -1004,8 +1004,6 @@ class TestAscendMLAImpl(TestBase): [chunk_seqlens, chunk_seqlens], dtype=torch.int32) attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens = kv_with_q_head_nomask_seqlens attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens = kv_with_q_tail_nomask_seqlens - attn_metadata.prefill.pcp_metadata.pcp_prefill_mask = torch.triu( - torch.ones(10, 10, dtype=torch.float16), 1) output = self.impl._forward_prefill(q_nope, q_pe, k_nope, k_pe, value, kv_c_and_k_pe_cache, diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index efbc3cdc..46a58626 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -244,8 +244,15 @@ class TestAscendMLAMetadataBuilder(TestBase): mock_vllm_config.scheduler_config.enable_chunked_prefill) @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") + @patch('vllm_ascend.attention.attention_mask.get_pcp_group') + @patch('vllm.distributed.parallel_state.get_pcp_group') def test_ascend_mla_metadata_builder_build_full_graph( - self, mock_get_cos_and_sin_mla): + self, mock_get_pcp_group, mock_get_pcp_group_mask, + mock_get_cos_and_sin_mla): + pcp_group = MagicMock() + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group + mock_get_pcp_group_mask.return_value = pcp_group mock_vllm_config = MagicMock() mock_vllm_config.model_config.max_model_len = 1024 mock_vllm_config.model_config.get_head_size.return_value = 64 @@ -400,14 +407,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.kv_cache_spec.num_heads = 32 @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") + @patch('vllm_ascend.attention.attention_mask.get_pcp_group') + @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros) @patch("torch.Tensor.npu", new=lambda self: self) @patch("torch.npu.is_available") def test_build_prefix_no_cache_metadata(self, mock_npu_available, - mock_zeros, + mock_zeros, mock_get_pcp_group, + mock_get_pcp_group_mask, mock_get_cos_and_sin_mla): mock_npu_available.return_value = False torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock() + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group + mock_get_pcp_group_mask.return_value = pcp_group def zeros_override(*args, **kwargs): kwargs.pop('pin_memory', None) @@ -426,8 +440,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1]), positions=torch.tensor([10, 10]), - attn_mask=torch.ones((10, 10)), - spec_attn_mask=None, attn_state=AscendAttentionState.PrefillNoCache, num_computed_tokens_cpu=None, seq_lens=None, @@ -458,14 +470,21 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") + @patch('vllm_ascend.attention.attention_mask.get_pcp_group') + @patch('vllm.distributed.parallel_state.get_pcp_group') @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros) @patch("torch.Tensor.npu", new=lambda self: self) 
@patch("torch.npu.is_available") def test_build_chunked_prefix_metadata(self, mock_npu_available, - mock_zeros, + mock_zeros, mock_get_pcp_group, + mock_get_pcp_group_mask, mock_get_cos_and_sin_mla): mock_npu_available.return_value = False torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock() + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group + mock_get_pcp_group_mask.return_value = pcp_group def zeros_override(*args, **kwargs): kwargs.pop('pin_memory', None) @@ -485,8 +504,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1, 2]), positions=torch.tensor([10, 10]), - attn_mask=torch.ones((15, 15)), - spec_attn_mask=None, attn_state=AscendAttentionState.ChunkedPrefill, num_computed_tokens_cpu=None, seq_lens=None, @@ -517,8 +534,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") - def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla): + @patch('vllm_ascend.attention.attention_mask.get_pcp_group') + @patch('vllm.distributed.parallel_state.get_pcp_group') + def test_build_decode_only_metadata(self, mock_get_pcp_group, + mock_get_pcp_group_mask, + mock_get_cos_and_sin_mla): torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock() + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group + mock_get_pcp_group_mask.return_value = pcp_group common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor([0, 1, 2, 3]), @@ -532,8 +557,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): actual_seq_lengths_q=torch.tensor([0, 1, 2]), decode_token_per_req=torch.tensor([1, 1, 1]), positions=torch.tensor([10, 10]), - attn_mask=torch.ones((3, 3)), - spec_attn_mask=None, attn_state=AscendAttentionState.DecodeOnly, num_computed_tokens_cpu=None, seq_lens=None, @@ -563,9 +586,16 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size) @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla") - def test_build_for_graph_capture_decode_only(self, + @patch('vllm_ascend.attention.attention_mask.get_pcp_group') + @patch('vllm.distributed.parallel_state.get_pcp_group') + def test_build_for_graph_capture_decode_only(self, mock_get_pcp_group, + mock_get_pcp_group_mask, mock_get_cos_and_sin_mla): torch.Tensor.pin_memory = lambda x: x # noqa + pcp_group = MagicMock() + pcp_group.world_size = 1 + mock_get_pcp_group.return_value = pcp_group + mock_get_pcp_group_mask.return_value = pcp_group common_attn_metadata = AscendCommonAttentionMetadata( query_start_loc=torch.tensor([0, 1, 2, 3]), @@ -579,8 +609,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): actual_seq_lengths_q=torch.tensor([0, 1, 2]), decode_token_per_req=torch.tensor([1, 1, 1]), positions=torch.tensor([10, 10]), - attn_mask=torch.ones((3, 3)), - spec_attn_mask=None, attn_state=AscendAttentionState.DecodeOnly, num_computed_tokens_cpu=None, seq_lens=None, @@ -625,8 +653,6 @@ class TestAscendMLAMetadataBuilderBuild(TestBase): slot_mapping=torch.tensor(range(20)), actual_seq_lengths_q=torch.tensor([0, 1]), positions=torch.tensor([10, 10]), - attn_mask=torch.ones((10, 10)), - spec_attn_mask=None, attn_state=AscendAttentionState.PrefillNoCache, num_computed_tokens_cpu=None, seq_lens=None, diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py 
index 918b6efb..307daacf 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -291,8 +291,6 @@ class TestMtpProposer: mock_runner = MagicMock() mock_runner.actual_seq_lengths_q = MagicMock() - mock_runner.attn_mask = MagicMock() - mock_runner.spec_attn_mask = MagicMock() mock_runner.attn_state = MagicMock() mock_runner.graph_pad_size = 0 mock_runner.decode_token_per_req = MagicMock() @@ -334,5 +332,3 @@ class TestMtpProposer: assert spec_common_attn_metadata.num_actual_tokens == total_num_tokens assert spec_common_attn_metadata.max_query_len == 8 assert spec_common_attn_metadata.actual_seq_lengths_q == proposer.runner.actual_seq_lengths_q - assert spec_common_attn_metadata.attn_mask == proposer.runner.attn_mask - assert spec_common_attn_metadata.spec_attn_mask == proposer.runner.spec_attn_mask diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py index ca94c895..eaa34a2e 100644 --- a/tests/ut/worker/test_pcp_manager.py +++ b/tests/ut/worker/test_pcp_manager.py @@ -70,7 +70,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, input_batch.num_tokens = torch.tensor(num_tokens) query_lens = torch.tensor(query_lens) - result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None, + result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, input_batch) if not expect_not_none: @@ -97,21 +97,6 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, assert hasattr(result, 'head_attn_nomask_seqlens') assert hasattr(result, 'tail_attn_nomask_seqlens') - if hasattr(result, 'pcp_prefill_mask' - ) and result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, 2048) - else: - if hasattr(result, 'pcp_prefill_mask'): - if result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, - 2048) - @pytest.mark.parametrize( "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens", diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 5bdfbd92..e7823b9e 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch +from vllm.distributed import get_pcp_group + +from vllm_ascend.platform import ModelConfig +from vllm_ascend.utils import singleton def _generate_attn_mask(max_seq_len, dtype): @@ -29,6 +33,7 @@ def _generate_attn_mask(max_seq_len, dtype): return attn_mask +@singleton class AttentionMaskBuilder: def __init__(self, device: torch.device): @@ -82,4 +87,16 @@ class AttentionMaskBuilder: triu_mask = torch.triu(mask, diagonal=1).to(self.device) tril_mask = torch.tril(mask, -sliding_window).to(self.device) self.swa_mask = triu_mask + tril_mask - return self.swa_mask \ No newline at end of file + return self.swa_mask + + def get_attention_mask(self, model_config: ModelConfig): + if model_config.runner_type == "pooling": + return self.get_attn_mask(2048, torch.bool) + + return self.get_splitfuse_attn_mask() + + def get_final_mla_mask(self, model_config: ModelConfig): + if get_pcp_group().world_size > 1: + return self.get_pcp_mla_mask(model_config.dtype) + # Prefill stages use 512x512 mask with appropriate dtype + return self.get_mla_mask(model_config.dtype) \ No newline at end of file diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 91051768..d19d3369 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.context_parallel.common_cp import ( AscendMetadataForDecode, AscendMetadataForPrefill) from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, @@ -219,6 +220,7 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): scheduler_config = vllm_config.scheduler_config self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill + self.attn_mask_builder = AttentionMaskBuilder(self.device) @classmethod def get_cudagraph_support( @@ -253,10 +255,19 @@ class AscendAttentionMetadataBuilder(AttentionMetadataBuilder[AscendMetadata]): seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens] - attn_mask = common_attn_metadata.attn_mask - swa_mask = common_attn_metadata.swa_mask attn_state = common_attn_metadata.attn_state + # Get attn_mask and swa_mask from singleton AttentionMaskBuilder + attn_mask = self.attn_mask_builder.get_attention_mask( + self.model_config) + + swa_mask = None + is_swa = hasattr(self.model_config.hf_text_config, 'sliding_window') + if self.model_config is not None and is_swa: + swa_mask = self.attn_mask_builder.get_swa_mask( + self.model_config.dtype, + self.model_config.hf_text_config.sliding_window) + # TODO: Yet another unnecessary H2D while we already have a query_start_loc on device query_start_loc = query_start_loc_cpu.pin_memory().to( self.device, non_blocking=True) diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 8d477909..affcf643 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -121,7 +121,8 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): slot_mapping = common_attn_metadata.slot_mapping[: num_actual_tokens_pcp_padded] - attn_mask = common_attn_metadata.attn_mask + attn_mask = 
common_attn_metadata.attn_mask + attn_mask =
self.attn_mask_builder.get_attention_mask( + self.model_config) attn_state = common_attn_metadata.attn_state num_computed_tokens_cpu = (seq_lens - query_lens) @@ -212,7 +213,6 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): head_attn_nomask_seqlens=head_attn_nomask_seqlens, tail_attn_nomask_seqlens=tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask, pcp_allgather_restore_idx=common_long_seq_metadata. pcp_allgather_restore_idx) @@ -433,13 +433,12 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens \ if is_head else attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens - mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask output, lse = self._attention_with_nomask_and_mask( **data, q_seqlens=attn_mask_seqlens, kv_seqlens_nomask=nomask_seqlens, kv_seqlens_mask=attn_mask_seqlens, - mask=mask, + mask=attn_metadata.attn_mask, attn_metadata=attn_metadata) return output, lse diff --git a/vllm_ascend/attention/context_parallel/common_cp.py b/vllm_ascend/attention/context_parallel/common_cp.py index 0652a9a5..018919c0 100644 --- a/vllm_ascend/attention/context_parallel/common_cp.py +++ b/vllm_ascend/attention/context_parallel/common_cp.py @@ -21,7 +21,6 @@ class AscendPCPMetadata: head_attn_nomask_seqlens: torch.Tensor = None tail_attn_nomask_seqlens: torch.Tensor = None q_full_idx: torch.Tensor = None - pcp_prefill_mask: torch.Tensor = None pcp_allgather_restore_idx: Optional[list[int]] = None diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 6c2425c1..8b37765c 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -118,7 +118,6 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): tail_attn_nomask_seqlens=common_long_seq_metadata. tail_attn_nomask_seqlens, q_full_idx=common_long_seq_metadata.q_full_idx, - pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask, pcp_allgather_restore_idx=common_long_seq_metadata. 
pcp_allgather_restore_idx) @@ -195,7 +194,7 @@ class AscendMlaCPMetadataBuilder(AscendMLAMetadataBuilder): ).item() if build_metadata_step == BUILD_METADATA_STEP_PREFILL: # For pcp + spec decode, we flatten seq_lens and block_table - # to avoid irregular spec_attn_mask shape + # to avoid irregular attn_mask shape return self.num_decodes_flatten + self.num_prefills else: return self.num_decodes_flatten @@ -420,7 +419,6 @@ class AscendMlaCPImpl(AscendMLAImpl): attn_mask_seqlens = attn_metadata.prefill.pcp_metadata.attn_mask_seqlens head_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.head_attn_nomask_seqlens tail_attn_nomask_seqlens = attn_metadata.prefill.pcp_metadata.tail_attn_nomask_seqlens - mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask output_head, lse_head = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_head_idx), q_pe=torch.index_select(q_pe, 0, q_head_idx), @@ -431,7 +429,7 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_head_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=head_attn_nomask_seqlens, - mask=mask) + mask=attn_metadata.attn_mask) output_tail, lse_tail = self._attention_with_mask_and_nomask( q_nope=torch.index_select(q_nope, 0, q_tail_idx), @@ -443,7 +441,7 @@ class AscendMlaCPImpl(AscendMLAImpl): kv_nomask_idx=kv_with_q_tail_nomask_idx, attn_mask_seqlens=attn_mask_seqlens, attn_nomask_seqlens=tail_attn_nomask_seqlens, - mask=mask) + mask=attn_metadata.attn_mask) q_full_idx = attn_metadata.prefill.pcp_metadata.q_full_idx attn_output = torch.index_select( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index aa4ed077..04b59dd5 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -18,6 +18,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.context_parallel.common_cp import ( AscendPCPMetadata, CPChunkedContextMetadata) @@ -263,6 +264,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): self.graph_pad_size = 0 self.query_lens: torch.Tensor = None self.seq_lens: torch.Tensor = None + self.attn_mask_builder = AttentionMaskBuilder(self.device) @classmethod def get_cudagraph_support( @@ -448,7 +450,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): num_decodes=self.num_decodes, num_decode_tokens=self.num_decode_tokens, num_prefills=self.num_prefills, - attn_mask=common_attn_metadata.attn_mask, + attn_mask=self.attn_mask_builder.get_final_mla_mask( + self.model_config), attn_state=common_attn_metadata.attn_state, prefill=prefill_metadata, decode=decode_metadata, @@ -542,7 +545,8 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): prefill_input_positions = input_positions[tokens_start:] cos, sin = get_cos_and_sin_mla(prefill_input_positions) return AscendMLAPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, + attn_mask=self.attn_mask_builder.get_final_mla_mask( + self.model_config), query_lens=self.query_lens[reqs_start:].to(torch.int32), seq_lens=self.seq_lens, context_lens=self.seq_lens[reqs_start:], @@ -643,7 +647,7 @@ class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]): seq_lens=self.seq_lens, seq_lens_list=seq_lens_list, 
max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, + attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(), actual_seq_lengths_q=actual_seq_lengths_q, sin=sin[:self.num_decode_tokens, ...], cos=cos[:self.num_decode_tokens, ...], @@ -1197,7 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl): # Output shape: [num_heads, num_tokens, dim] attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank) sparse_mode = 3 - spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore + attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: # The output layout is set to NBSD to eliminate the need for a @@ -1218,7 +1222,7 @@ class AscendMLAImpl(MLAAttentionImpl): attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) sparse_mode = 0 - spec_attn_mask = None + attn_mask = None common_kwargs = { 'query_rope': q_pe, @@ -1226,7 +1230,7 @@ class AscendMLAImpl(MLAAttentionImpl): 'num_heads': self.num_heads, 'num_key_value_heads': self.num_kv_heads, 'input_layout': input_layout, - 'atten_mask': spec_attn_mask, + 'atten_mask': attn_mask, 'sparse_mode': sparse_mode, 'scale': self.scale, 'antiquant_mode': 0, @@ -1269,8 +1273,8 @@ class AscendMLAImpl(MLAAttentionImpl): (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), self.num_heads, self.num_kv_heads, input_layout, - weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None - else None, sparse_mode, self.scale, decode_meta.block_table, + weak_ref_tensors(attn_mask) if attn_mask is not None else + None, sparse_mode, self.scale, decode_meta.block_table, block_size, decode_meta.seq_lens_list, actual_seq_lengths, weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse))) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 119eef56..7430575d 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -19,6 +19,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, @@ -156,6 +157,7 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): and self.vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1." 
+ self.attn_mask_builder = AttentionMaskBuilder(self.device) @classmethod def get_cudagraph_support( @@ -280,7 +282,8 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): seq_lens=seq_lens, slot_mapping=slot_mapping, head_dim=self.model_config.get_head_size(), - attn_mask=common_attn_metadata.attn_mask, + attn_mask=self.attn_mask_builder.get_attention_mask( + self.model_config), attn_state=common_attn_metadata.attn_state, block_tables=block_table, sin=sin[:num_input_tokens], diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index af353838..9168224c 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -66,8 +66,6 @@ class AscendPrefillContextParallelMetadata: q_full_idx: torch.Tensor = None - pcp_prefill_mask: torch.Tensor = None - # original query_lens before pcp split query_lens_pcp_full_cpu: torch.Tensor = None @@ -93,12 +91,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): positions: torch.Tensor = None - attn_mask: torch.Tensor = None - - spec_attn_mask: torch.Tensor = None - - swa_mask: torch.Tensor = None - attn_state: Any = None graph_pad_size: int = -1 @@ -130,9 +122,6 @@ class AscendCommonAttentionMetadata(CommonAttentionMetadata): causal=self.causal, actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens], positions=self.positions[:num_actual_tokens], - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - swa_mask=self.swa_mask, attn_state=self.attn_state, graph_pad_size=-1, # It should be -1 when not run in fullgraph mode. num_input_tokens=num_actual_tokens, diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 72cf925d..29ec5793 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -340,7 +340,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, graph_params.events[runtime_shape], ): (q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, - spec_attn_mask, sparse_mode, scale, block_table, block_size, + attn_mask, sparse_mode, scale, block_table, block_size, seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param seq_lens_list = forward_context.attn_metadata[ @@ -380,7 +380,7 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape, num_heads=num_heads, num_key_value_heads=num_kv_heads, input_layout=input_layout, - atten_mask=spec_attn_mask, + atten_mask=attn_mask, sparse_mode=sparse_mode, scale=scale, antiquant_mode=0, @@ -480,7 +480,7 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, seq_len = decode_meta.cp_seq_len # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape, + # to avoid irregular attn_mask shape, # so there's no need to divide runtime_shape by spec_multiple pad_length = runtime_shape - len(seq_len) pad_tensor = torch.zeros(pad_length, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 7fede206..d875d23e 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -222,8 +222,6 @@ class EagleProposer(VllmEagleProposer): slot_mapping=self.runner.input_batch.block_table[0]. 
slot_mapping.gpu, positions=self.runner.positions.gpu, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, max_seq_len=0, @@ -672,8 +670,6 @@ class EagleProposer(VllmEagleProposer): slot_mapping=common_attn_metadata.slot_mapping, actual_seq_lengths_q=self.runner.actual_seq_lengths_q, positions=common_attn_metadata.positions[token_indices], - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, max_seq_len=0) @@ -762,8 +758,6 @@ class EagleProposer(VllmEagleProposer): block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping, positions=common_attn_metadata.positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, num_computed_tokens_cpu=common_attn_metadata. diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 6f68b488..e95e5f2f 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union import torch import torch.nn as nn @@ -73,8 +73,6 @@ class MtpProposer(EagleProposer): slot_mapping=self.runner.input_batch.block_table[0]. slot_mapping.gpu, positions=self.runner.positions.gpu, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, attn_state=self.runner.attn_state, decode_token_per_req=self.runner.decode_token_per_req, max_seq_len=0) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 3527b910..80565554 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1150,3 +1150,14 @@ def check_kv_extra_config(vllm_config): _check( "decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {})) + + +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 70f60012..e77fb706 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -77,7 +77,6 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput from vllm.v1.worker.utils import AttentionGroup from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata # yapf conflicts with isort for this block @@ -230,7 +229,6 @@ class NPUModelRunner(GPUModelRunner): self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64) self.sampler = AscendSampler() - self.attn_mask = None self.attn_state = None # Ascend-specific configurations @@ -264,19 +262,9 @@ class NPUModelRunner(GPUModelRunner): use_sparse=self.use_sparse, use_mm_prefix=self.model_config is not None and self.model_config.is_mm_prefix_lm) - self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() - # sliding window attn mask - self.swa_mask = None - is_swa = hasattr(self.vllm_config.model_config.hf_text_config, - 
"sliding_window") - if self.model_config is not None and is_swa: - self.swa_mask = self.attn_mask_builder.get_swa_mask( - self.dtype, - self.vllm_config.model_config.hf_text_config.sliding_window) - # kv role self.is_kv_producer = False self.is_kv_consumer = False @@ -370,7 +358,6 @@ class NPUModelRunner(GPUModelRunner): def _set_up_drafter(self): # Set up speculative decoding. - self.spec_attn_mask = None self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer, SuffixDecodingProposer]] = None self.actual_seq_lengths_q: list[int] = [] @@ -379,8 +366,6 @@ class NPUModelRunner(GPUModelRunner): spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num - self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( - ) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() if self.speculative_config.method == "eagle3": @@ -494,22 +479,6 @@ class NPUModelRunner(GPUModelRunner): return self.model.unwrap() return self.model - def _make_attention_mask(self, attn_state) -> torch.Tensor: - # pcp situation. - if self.attn_mask_builder is None: - raise ValueError("Attn mask builder is None") - # Pooling situation. - if self.model_config.runner_type == "pooling": - return self.attn_mask_builder.get_attn_mask(2048, torch.bool) - - if self.vllm_config.model_config.use_mla: - if self.pcp_size > 1: - return self.attn_mask_builder.get_pcp_mla_mask(self.dtype) - # mla prefill - if attn_state != AscendAttentionState.DecodeOnly: - return self.attn_mask_builder.get_mla_mask(self.dtype) - return self.attn_mask_builder.get_splitfuse_attn_mask() - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -551,7 +520,6 @@ class NPUModelRunner(GPUModelRunner): with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] - self.attn_mask = self._make_attention_mask(attn_state) # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] @@ -941,7 +909,7 @@ class NPUModelRunner(GPUModelRunner): if self.pcp_size * self.dcp_size > 1: self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( total_num_scheduled_tokens, self.query_lens, - self.attn_mask, self.input_batch) + self.input_batch) blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1) if self.pcp_size > 1: slot_mapping = self.pcp_manager.get_padded_slot_mapping( @@ -997,9 +965,6 @@ class NPUModelRunner(GPUModelRunner): num_computed_tokens_cpu=self.input_batch. 
num_computed_tokens_cpu_tensor[:num_reqs], positions=self.positions.gpu, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - swa_mask=self.swa_mask, attn_state=self.attn_state, max_query_len=max_num_scheduled_tokens, decode_token_per_req=self.decode_token_per_req, @@ -1009,7 +974,7 @@ class NPUModelRunner(GPUModelRunner): if self.speculative_config and self.pcp_size * self.dcp_size > 1: # For pcp + spec decode, we flatten block_table - # to avoid irregular spec_attn_mask shape, e.g., + # to avoid irregular attn_mask shape, e.g., # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1, # ori block_table: # [d0, d1, p0, p1, p2] # (num_reqs_d + num_reqs_p, max_num_blocks), @@ -1918,7 +1883,6 @@ class NPUModelRunner(GPUModelRunner): self.query_start_loc.cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens) self.query_lens = torch.from_numpy(num_scheduled_tokens) - self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) @@ -1930,8 +1894,7 @@ class NPUModelRunner(GPUModelRunner): slot_mapping = self.input_batch.block_table[ kv_cache_group_id].slot_mapping long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata( - num_tokens, self.query_lens, self.attn_mask, - self.input_batch) + num_tokens, self.query_lens, self.input_batch) if long_seq_metadata is not None: pcp_world_size = get_pcp_group().world_size dcp_world_size = get_dcp_group().world_size @@ -1954,9 +1917,6 @@ class NPUModelRunner(GPUModelRunner): slot_mapping=slot_mapping.gpu, num_computed_tokens_cpu=num_computed_tokens_cpu, positions=self.positions.gpu, - attn_mask=self.attn_mask, - spec_attn_mask=self.spec_attn_mask, - swa_mask=self.swa_mask, attn_state=self.attn_state, max_query_len=max_query_len, decode_token_per_req=self.decode_token_per_req, diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index cded9f54..f0ace8a4 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -498,7 +498,7 @@ class PCPManager: torch.float32).argsort().to(torch.int32) def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, - attn_mask, input_batch): + input_batch): from vllm_ascend.attention.utils import \ AscendPrefillContextParallelMetadata num_reqs = input_batch.num_reqs or query_lens.size(0) @@ -523,7 +523,7 @@ class PCPManager: dtype=torch.int32, ) # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape. + # to avoid irregular attn_mask shape. # Same as block_table, we flatten decode seq_lens to query_lens, # and keep prefill seq_lens unchanged. 
for decode_idx in range(self.decode_threshold): @@ -657,13 +657,11 @@ class PCPManager: split_with_q_head_nomask_idx_reqs, split_kv_with_q_tail_nomask_idx_reqs, head_attn_nomask_seqlens, chunk_seqlens) - pcp_prefill_mask = attn_mask self.extra_long_seq_kwargs = { 'attn_mask_seqlens': attn_mask_seqlens, 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'pcp_prefill_mask': pcp_prefill_mask + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens } long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: num_actual_tokens_pcp_padded] @@ -685,8 +683,6 @@ class PCPManager: 'head_attn_nomask_seqlens'] long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ 'tail_attn_nomask_seqlens'] - long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ - 'pcp_prefill_mask'] if self.vllm_config.model_config.use_mla: long_seq_metadata.kv_with_q_head_nomask_idx_tensor = split_q_head_nomask_idx_tensor_list long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py index 655c0369..738a84c3 100644 --- a/vllm_ascend/worker/v2/attn_utils.py +++ b/vllm_ascend/worker/v2/attn_utils.py @@ -58,9 +58,6 @@ def build_attn_metadata( decode_token_per_req: int, actual_seq_lengths_q: list[int], positions: torch.Tensor | None = None, - attn_mask: torch.Tensor - | None = None, - spec_attn_mask: torch.Tensor | None = None, attn_state: Any | None = None, graph_pad_size: int = -1, num_input_tokens: int = 0, @@ -92,8 +89,6 @@ def build_attn_metadata( slot_mapping=slot_mapping, actual_seq_lengths_q=actual_seq_lengths_q, positions=positions, - attn_mask=attn_mask, - spec_attn_mask=spec_attn_mask, attn_state=attn_state, graph_pad_size=graph_pad_size, num_input_tokens=num_input_tokens, diff --git a/vllm_ascend/worker/v2/model_runner.py b/vllm_ascend/worker/v2/model_runner.py index 447fdf8d..99987c5d 100644 --- a/vllm_ascend/worker/v2/model_runner.py +++ b/vllm_ascend/worker/v2/model_runner.py @@ -32,8 +32,7 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata, - build_attn_state, - make_attention_mask) + build_attn_state) from vllm_ascend.worker.v2.input_batch import AscendInputBuffers from vllm_ascend.worker.v2.sample.sampler import AscendSampler from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper @@ -155,12 +154,6 @@ class NPUModelRunner(GPUModelRunner): num_scheduled_tokens, num_valid_tokens, ) - attn_mask = make_attention_mask( - self.vllm_config, - attn_state, - self.dtype, - self.device, - ) idx_mapping_list = [ self.req_states.req_id_to_index[req_id] for req_id in req_ids @@ -284,7 +277,6 @@ class NPUModelRunner(GPUModelRunner): slot_mappings=slot_mappings.to(torch.int32), kv_cache_config=self.kv_cache_config, decode_token_per_req=self.decode_token_per_req, - attn_mask=attn_mask, attn_state=attn_state, )
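
For reference, below is a minimal standalone sketch of the `@singleton` behavior this patch relies on. The decorator mirrors the one added to `vllm_ascend/utils.py`; the `MaskBuilder` class is a simplified, hypothetical stand-in for `AttentionMaskBuilder` (the mask shape and dtype are illustrative only, not the real implementation).

```python
import torch


def singleton(cls):
    """Mirror of the decorator added to vllm_ascend/utils.py."""
    instances = {}

    def get_instance(*args, **kwargs):
        # The first call constructs the instance; every later call returns
        # the cached object and ignores any new constructor arguments.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class MaskBuilder:
    """Simplified stand-in for AttentionMaskBuilder (illustrative only)."""

    def __init__(self, device: torch.device):
        self.device = device
        # Illustrative causal mask; the real builder caches several mask
        # variants (splitfuse, MLA, SWA, PCP) on the given device.
        self.splitfuse_mask = torch.triu(
            torch.ones(2048, 2048, device=device), diagonal=1).bool()


# Both metadata builders now call MaskBuilder(self.device) and receive the
# same shared instance, so the mask tensors are built once per process.
builder_a = MaskBuilder(torch.device("cpu"))
builder_b = MaskBuilder(torch.device("cpu"))
assert builder_a is builder_b
```

Because the decorator caches by class and ignores constructor arguments after the first call, the very first construction must supply a valid `device`; a parameterless call raises a `TypeError`, which is why the remaining `AttentionMaskBuilder()` sites had to become `AttentionMaskBuilder(self.device)`.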