From c331503677d6c57a2387ed7c5c58a4ce956674d9 Mon Sep 17 00:00:00 2001
From: weijinqian0 <1184188277@qq.com>
Date: Tue, 9 Dec 2025 18:51:00 +0800
Subject: [PATCH] [Refactor] 2/N Unify all mask generation methods and cache mask (#4779)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason: There are several kinds of attention masks, and some of them have no
caching mechanism, so those masks are re-initialized for every layer, which
wastes device memory. We also want to standardize how masks are managed and
used, so all mask generation is now gathered into the AttentionMaskBuilder
class.

Todo:
1. remove spec_attn_mask; @LICO1314
2. remove pcp_prefill_mask; @LICO1314

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: wangxiyuan
Signed-off-by: weijinqian_v1
Signed-off-by: ZYang6263
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun
Signed-off-by: lulina
Signed-off-by: zengran
Signed-off-by: shiro-zzzz
Signed-off-by: dependabot[bot]
Signed-off-by: 李少鹏
Signed-off-by: xuyexiong
Signed-off-by: MengqingCao
Signed-off-by: lhp-deep
Signed-off-by: gcanlin
Signed-off-by: wangli
Co-authored-by: wangxiyuan
Co-authored-by: Mengqing Cao
Co-authored-by: weijinqian_v1
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina
Co-authored-by: zengzengran
Co-authored-by: shiro-zzzz
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666
Co-authored-by: xuyexiong
Co-authored-by: lhp-deep
Co-authored-by: Canlin Guo
Co-authored-by: Li Wang
---
 tests/ut/attention/test_attention_mask.py | 60 +++-------------
 vllm_ascend/attention/attention_mask.py   | 88 ++++++++++-------------
 vllm_ascend/attention/mla_v1.py           | 21 +-----
 vllm_ascend/attention/utils.py            |  2 -
 vllm_ascend/spec_decode/eagle_proposer.py |  8 +--
 vllm_ascend/worker/model_runner_v1.py     | 61 ++++------------
 6 files changed, 66 insertions(+), 174 deletions(-)

diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py
index 9bd4cd0e..cabffda4 100644
--- a/tests/ut/attention/test_attention_mask.py
+++ b/tests/ut/attention/test_attention_mask.py
@@ -21,58 +21,23 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 
 class TestAttentionMaskBuilder(TestBase):
 
-    def test_init_attention_mask_builder(self):
-        # generate attention_mask_builder with float16
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.float16)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.float16)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (1024, 1024))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(float("-inf"), dtype=torch.float16))
-
-        # generate attention_mask_builder with bfloat16
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
-                                                      dtype=torch.bfloat16)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.bfloat16)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (2048, 2048))
-
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], - torch.tensor(1, dtype=torch.bfloat16)) - - def test_get_mask_scale_factor(self): - # supported data types - self.assertEqual( - AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1) - self.assertEqual( - AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000) - # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16 - # Otherwise raise ValueError - with self.assertRaises(ValueError): - AttentionMaskBuilder.get_mask_scale_factor(torch.int8) - def test_get_attn_mask(self): # if the len is less than max_seq_len, the attn_mask_cache will not be updated - attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, - dtype=torch.float16) - attn_mask = attention_mask_builder.get_attn_mask( - max_seq_len=512, dtype=torch.float16, device=torch.device("cpu")) + attention_mask_builder = AttentionMaskBuilder(torch.device("cpu")) + attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=512, + dtype=torch.float16) self.assertEqual(attn_mask.shape, (512, 512)) self.assertEqual(attn_mask[0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) - self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attention_mask_builder._seq_len_cached, 512) self.assertEqual(attention_mask_builder.attn_mask_cache.shape, - (1024, 1024)) + (512, 512)) self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) # if the len is greater than max_seq_len, the attn_mask_cache will be updated - attn_mask = attention_mask_builder.get_attn_mask( - max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu")) + attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=2048, + dtype=torch.float16) self.assertEqual(attn_mask.shape, (2048, 2048)) self.assertEqual(attn_mask[0][-1], torch.tensor(float("-inf"), dtype=torch.float16)) @@ -83,13 +48,6 @@ class TestAttentionMaskBuilder(TestBase): torch.tensor(float("-inf"), dtype=torch.float16)) def test_get_splitfuse_attn_mask(self): - attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, - dtype=torch.float16) - attn_mask = attention_mask_builder.get_splitfuse_attn_mask( - seq_lens=torch.tensor([10, 20, 100]), - position=torch.tensor([7, 8, 9, 18, 19, 99]), - dtype=torch.float16, - device=torch.device("cpu"), - ) + attention_mask_builder = AttentionMaskBuilder(torch.device("cpu")) + attn_mask = attention_mask_builder.get_splitfuse_attn_mask() self.assertEqual(attn_mask.shape, (2048, 2048)) - self.assertEqual(attention_mask_builder._seq_len_cached, 1024) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 2c963b5c..c1322351 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -31,66 +31,54 @@ def _generate_attn_mask(max_seq_len, dtype): class AttentionMaskBuilder: - def __init__( - self, - max_seq_len: int, - dtype: torch.dtype, - device: torch.device = None, - ): - # NOTE: The device argument specifies the target NPU - # to be used for the newly added FIA operator. - # Only pass this parameter when using the new FIA operator. 
- - attn_mask = _generate_attn_mask(max_seq_len, dtype) - - self._seq_len_cached = attn_mask.shape[0] - self.attn_mask_cache = attn_mask + def __init__(self, device: torch.device): + self.attn_mask_cache = None + self._seq_len_cached = 0 self.device = device self.pooling_mask = None - assigned_mask_dim = 2048 - self.chunked_prefill_attn_mask = torch.triu( - torch.ones(assigned_mask_dim, assigned_mask_dim), - diagonal=1).to(torch.int8).to(device) + self.mla_mask = None + self.chunked_prefill_attn_mask = None + self.pcp_mla_mask = None - @staticmethod - def get_mask_scale_factor(dtype: torch.dtype = torch.float16): - if dtype == torch.float16: - mask_scale_factor = 1 - elif dtype == torch.bfloat16: - mask_scale_factor = -10000 - else: - raise ValueError( - "The current operation now only supports data types: torch.float16 and " - "torch.bfloat16. Please ensure the input is of one of these types." - ) - return mask_scale_factor - - def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, - device: torch.device): - self._update_attn_cache(max_seq_len, dtype) + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype): + if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached: + self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype) + self._seq_len_cached = max_seq_len + assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask." + if self.attn_mask_cache.dtype != dtype: + self.attn_mask_cache = self.attn_mask_cache.to(dtype) return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( - ).to(device, non_blocking=True) + ).to(self.device, non_blocking=True) - def get_pooling_mask(self, device): + def get_pooling_mask(self): if self.pooling_mask is None: # the compressed attention mask for npu_fusion_attention sparse mode 4 self.pooling_mask = torch.triu(torch.ones( - 2048, 2048), diagonal=1).to(torch.bool).to(device, + 2048, 2048), diagonal=1).to(torch.bool).to(self.device, non_blocking=True) return self.pooling_mask - def get_splitfuse_attn_mask( - self, - seq_lens: torch.Tensor = None, - position: torch.Tensor = None, - dtype: torch.dtype = None, - device: torch.device = None, - ) -> torch.Tensor: + def get_splitfuse_attn_mask(self) -> torch.Tensor: + if self.chunked_prefill_attn_mask is None: + self.chunked_prefill_attn_mask = torch.triu( + torch.ones(2048, + 2048), diagonal=1).to(torch.int8).to(self.device) return self.chunked_prefill_attn_mask - def _update_attn_cache(self, seqlen: int, dtype: torch.dtype): - if seqlen > self._seq_len_cached: - self._seq_len_cached = seqlen - self.attn_mask_cache = _generate_attn_mask(seqlen, dtype) - if self.attn_mask_cache.dtype != dtype: - self.attn_mask_cache = self.attn_mask_cache.to(dtype) + def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor: + if self.mla_mask is None or self.mla_mask.dtype != dtype: + if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + prefill_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=dtype), 1) + self.mla_mask = torch.where(prefill_mask == 1, mask_value, + 0).to(dtype) + return self.mla_mask + + def get_pcp_mla_mask(self, dtype: torch.dtype): + if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype: + self.pcp_mla_mask = torch.triu( + torch.ones(512, 512, device=self.device, dtype=dtype), 1) + return self.pcp_mla_mask diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index cd2111fa..a6a6447c 100644 --- a/vllm_ascend/attention/mla_v1.py +++ 
b/vllm_ascend/attention/mla_v1.py @@ -202,7 +202,6 @@ class AscendMLAMetadataBuilder: understand this class """ - # _attn_mask_builder = None def __init__(self, kv_cache_spec, layer_names, @@ -862,7 +861,6 @@ class AscendMLAImpl(MLAAttentionImpl): vllm_config = get_current_vllm_config() self.ring_mla_mask_size = 512 - self.prefill_mask = None self.speculative_config = vllm_config.speculative_config self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO @@ -1167,10 +1165,7 @@ class AscendMLAImpl(MLAAttentionImpl): .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) - if self.pcp_size > 1: - mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask - else: - mask = self.prefill_mask + mask = attn_metadata.attn_mask torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe, @@ -1214,24 +1209,12 @@ class AscendMLAImpl(MLAAttentionImpl): num_tokens, dtype=torch.float32, device=q_nope.device) - if self.prefill_mask is None: - if q_nope.dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - prefill_mask = torch.triu( - torch.ones(self.ring_mla_mask_size, - self.ring_mla_mask_size, - device=q_nope.device, - dtype=q_nope.dtype), 1) - self.prefill_mask = torch.where(prefill_mask == 1, mask_value, - 0).to(q_nope.dtype) torch_npu.atb.npu_ring_mla(q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=value, - mask=self.prefill_mask, + mask=attn_metadata.attn_mask, seqlen=attn_metadata.prefill.query_lens, head_num=self.num_heads, kv_head_num=self.num_heads, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index e929dacc..a2f71de7 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -88,8 +88,6 @@ class AscendCommonAttentionMetadata: attn_mask: torch.Tensor = None - fia_attn_mask: torch.Tensor = None - spec_attn_mask: torch.Tensor = None attn_state: Any = None diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index af7d3689..27a7f717 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -77,9 +77,7 @@ class EagleProposer(Proposer): 1, device=device, dtype=torch.int32) - attn_mask_len = self.vllm_config.model_config.max_model_len - self.attn_mask_builder = AttentionMaskBuilder( - attn_mask_len, self.vllm_config.model_config.dtype, device=device) + self.attn_mask_builder = AttentionMaskBuilder(self.device) def load_model(self, model: nn.Module) -> None: target_attn_layer_names = set( @@ -570,9 +568,7 @@ class EagleProposer(Proposer): self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states - attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( - attn_metadata.seq_lens, positions_cpu, - self.vllm_config.model_config.dtype, self.device) + attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() attn_metadata.attn_mask = attn_mask attn_metadata.block_tables = block_table.to(device) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 56cef32c..f6fded7d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -378,12 +378,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.block_size, use_mla=self.model_config.use_mla, use_sparse=self.use_sparse) - if self.pcp_size > 1: - self.attn_mask_builder = None - else: - self.attn_mask_builder = AttentionMaskBuilder( - 
self.scheduler_config.max_num_batched_tokens, self.dtype, - self.device) + self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() @@ -651,10 +646,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): spec_token_num = self.speculative_config.num_speculative_tokens assert spec_token_num > 0 self.decode_token_per_req = 1 + spec_token_num - self.spec_attn_mask = torch.triu(torch.ones(2048, - 2048, - dtype=torch.bool), - diagonal=1).to(self.device) + self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask( + ) if get_pp_group().is_last_rank: self.drafter = self._get_drafter() self.rejection_sampler = AscendRejectionSampler(self.sampler) @@ -1033,21 +1026,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): return tuple(tasks) - def _make_attention_mask(self, seq_lens, position, - attn_state) -> torch.Tensor: + def _make_attention_mask(self, attn_state) -> torch.Tensor: # pcp situation. - if self.pcp_size > 1: - return None if self.attn_mask_builder is None: raise ValueError("Attn mask builder is None") - # dcp situation. - if self.dcp_size > 1: - return self.attn_mask_builder.get_splitfuse_attn_mask() - if self.vllm_config.model_config.use_mla: - return None # Pooling situation. if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": - return self.attn_mask_builder.get_pooling_mask(self.device) + return self.attn_mask_builder.get_pooling_mask() + + if self.vllm_config.model_config.use_mla: + if self.pcp_size > 1: + return self.attn_mask_builder.get_pcp_mla_mask(self.dtype) + # mla prefill + if attn_state != AscendAttentionState.DecodeOnly: + return self.attn_mask_builder.get_mla_mask(self.dtype) return self.attn_mask_builder.get_splitfuse_attn_mask() def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): @@ -1668,16 +1660,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.positions[:num_input_tokens].copy_( self.positions_cpu[:num_input_tokens], non_blocking=True) - # Make Attention metadata - positions_cpu = self.positions_cpu[:num_input_tokens] - positions = self.positions[:num_input_tokens] - seq_lens_cpu = self.seq_lens_cpu[:num_reqs] - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) - self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu, - position=positions_cpu, - attn_state=attn_state) + self.attn_mask = self._make_attention_mask(attn_state) self.attn_state = attn_state # type: ignore self.with_prefill = with_prefill @@ -2840,12 +2825,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens) self.query_lens = torch.from_numpy(num_scheduled_tokens) - - assigned_mask_dim = 2048 - self.attn_mask = torch.triu(torch.ones(assigned_mask_dim, - assigned_mask_dim), - diagonal=1).to(torch.int8).to( - self.device) + self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() num_computed_tokens_cpu = ( self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) @@ -4499,18 +4479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): tail_attn_nomask_seqlens = torch.tensor( [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32) - if self.vllm_config.model_config.use_mla: - pcp_prefill_mask = torch.triu( - torch.ones(512, - 512, - device=self.device, - dtype=self.dtype), 1) - else: - pcp_prefill_mask = torch.triu( - torch.full((2048, 
2048), - True, - device=self.device, - dtype=torch.bool), 1) + pcp_prefill_mask = self.attn_mask self.extra_long_seq_kwargs = { 'attn_mask_seqlens': attn_mask_seqlens,
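For reviewers, below is a minimal runnable sketch of the caching behavior this
patch introduces, simplified to plain torch so it runs on CPU. It is an
illustration, not the full class: the pooling, MLA, and PCP-MLA masks are
omitted, and the body of _generate_attn_mask is not shown in the hunks above,
so its implementation here is an assumption inferred from the values the
updated tests assert (-inf for float16, 1 for bfloat16).

import torch


def _generate_attn_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Assumed helper body: upper-triangular causal mask whose masked value
    # matches what the tests check (-inf for float16, 1 otherwise).
    mask_value = float("-inf") if dtype == torch.float16 else 1
    mask = torch.zeros(max_seq_len, max_seq_len, dtype=dtype)
    causal = torch.triu(
        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
    return mask.masked_fill(causal, mask_value)


class AttentionMaskBuilder:

    def __init__(self, device: torch.device):
        self.device = device
        self.attn_mask_cache = None  # grown on demand, shared by all layers
        self._seq_len_cached = 0
        self.chunked_prefill_attn_mask = None

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
        # Regenerate only when there is no cache yet or a longer sequence
        # arrives; shorter requests are served as slices of the cached mask.
        if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
            self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype)
            self._seq_len_cached = max_seq_len
        if self.attn_mask_cache.dtype != dtype:
            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
        ).to(self.device, non_blocking=True)

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        # Fixed 2048x2048 int8 causal mask, built lazily once and then reused
        # by every caller (model runner, eagle proposer, spec decode).
        if self.chunked_prefill_attn_mask is None:
            self.chunked_prefill_attn_mask = torch.triu(
                torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(
                    self.device)
        return self.chunked_prefill_attn_mask


builder = AttentionMaskBuilder(torch.device("cpu"))
print(builder.get_attn_mask(512, torch.float16).shape)   # (512, 512), cache built
print(builder.get_attn_mask(256, torch.float16).shape)   # (256, 256), cache hit
print(builder.get_attn_mask(2048, torch.float16).shape)  # (2048, 2048), cache regrown
print(builder.get_splitfuse_attn_mask().dtype)           # torch.int8, built once

Because every caller now takes its mask from one shared builder, each mask
tensor exists once per model runner instead of once per layer, which is where
the device-memory saving described above comes from.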