diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py
index 9bd4cd0e..cabffda4 100644
--- a/tests/ut/attention/test_attention_mask.py
+++ b/tests/ut/attention/test_attention_mask.py
@@ -21,58 +21,23 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 
 class TestAttentionMaskBuilder(TestBase):
 
-    def test_init_attention_mask_builder(self):
-        # generate attention_mask_builder with float16
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.float16)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.float16)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (1024, 1024))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(float("-inf"), dtype=torch.float16))
-
-        # generate attention_mask_builder with bfloat16
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
-                                                      dtype=torch.bfloat16)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.bfloat16)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (2048, 2048))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.bfloat16))
-
-    def test_get_mask_scale_factor(self):
-        # supported data types
-        self.assertEqual(
-            AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
-        self.assertEqual(
-            AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
-        # mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
-        # Otherwise raise ValueError
-        with self.assertRaises(ValueError):
-            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
-
     def test_get_attn_mask(self):
-        # if the len is less than max_seq_len, the attn_mask_cache will not be updated
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.float16)
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
+        # the attn_mask_cache is created lazily at the first requested size
+        attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
+        attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=512,
+                                                         dtype=torch.float16)
         self.assertEqual(attn_mask.shape, (512, 512))
         self.assertEqual(attn_mask[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (1024, 1024))
+                         (512, 512))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
         # if the len is greater than max_seq_len, the attn_mask_cache will be updated
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
+        attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=2048,
+                                                         dtype=torch.float16)
         self.assertEqual(attn_mask.shape, (2048, 2048))
         self.assertEqual(attn_mask[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
@@ -83,13 +48,6 @@ class TestAttentionMaskBuilder(TestBase):
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
     def test_get_splitfuse_attn_mask(self):
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.float16)
-        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=torch.tensor([10, 20, 100]),
-            position=torch.tensor([7, 8, 9, 18, 19, 99]),
-            dtype=torch.float16,
-            device=torch.device("cpu"),
-        )
+        attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
+        attn_mask = attention_mask_builder.get_splitfuse_attn_mask()
         self.assertEqual(attn_mask.shape, (2048, 2048))
-        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py
index 2c963b5c..c1322351 100644
--- a/vllm_ascend/attention/attention_mask.py
+++ b/vllm_ascend/attention/attention_mask.py
@@ -31,66 +31,54 @@ def _generate_attn_mask(max_seq_len, dtype):
 
 class AttentionMaskBuilder:
 
-    def __init__(
-        self,
-        max_seq_len: int,
-        dtype: torch.dtype,
-        device: torch.device = None,
-    ):
-        # NOTE: The device argument specifies the target NPU
-        # to be used for the newly added FIA operator.
-        # Only pass this parameter when using the new FIA operator.
-
-        attn_mask = _generate_attn_mask(max_seq_len, dtype)
-
-        self._seq_len_cached = attn_mask.shape[0]
-        self.attn_mask_cache = attn_mask
+    def __init__(self, device: torch.device):
+        self.attn_mask_cache = None
+        self._seq_len_cached = 0
         self.device = device
         self.pooling_mask = None
-        assigned_mask_dim = 2048
-        self.chunked_prefill_attn_mask = torch.triu(
-            torch.ones(assigned_mask_dim, assigned_mask_dim),
-            diagonal=1).to(torch.int8).to(device)
+        self.mla_mask = None
+        self.chunked_prefill_attn_mask = None
+        self.pcp_mla_mask = None
 
-    @staticmethod
-    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
-        if dtype == torch.float16:
-            mask_scale_factor = 1
-        elif dtype == torch.bfloat16:
-            mask_scale_factor = -10000
-        else:
-            raise ValueError(
-                "The current operation now only supports data types: torch.float16 and "
-                "torch.bfloat16. Please ensure the input is of one of these types."
-            )
-        return mask_scale_factor
-
-    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
-                      device: torch.device):
-        self._update_attn_cache(max_seq_len, dtype)
+    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
+        if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
+            self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype)
+            self._seq_len_cached = max_seq_len
+        assert self.attn_mask_cache is not None, "attn_mask_cache should have been initialized by _generate_attn_mask."
+        if self.attn_mask_cache.dtype != dtype:
+            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
-        ).to(device, non_blocking=True)
+        ).to(self.device, non_blocking=True)
 
-    def get_pooling_mask(self, device):
+    def get_pooling_mask(self):
         if self.pooling_mask is None:
             # the compressed attention mask for npu_fusion_attention sparse mode 4
             self.pooling_mask = torch.triu(torch.ones(
-                2048, 2048), diagonal=1).to(torch.bool).to(device,
+                2048, 2048), diagonal=1).to(torch.bool).to(self.device,
                                                            non_blocking=True)
         return self.pooling_mask
 
-    def get_splitfuse_attn_mask(
-        self,
-        seq_lens: torch.Tensor = None,
-        position: torch.Tensor = None,
-        dtype: torch.dtype = None,
-        device: torch.device = None,
-    ) -> torch.Tensor:
+    def get_splitfuse_attn_mask(self) -> torch.Tensor:
+        if self.chunked_prefill_attn_mask is None:
+            self.chunked_prefill_attn_mask = torch.triu(
+                torch.ones(2048,
+                           2048), diagonal=1).to(torch.int8).to(self.device)
         return self.chunked_prefill_attn_mask
 
-    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
-        if seqlen > self._seq_len_cached:
-            self._seq_len_cached = seqlen
-            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.dtype != dtype:
-            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
+    def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
+        if self.mla_mask is None or self.mla_mask.dtype != dtype:
+            if dtype == torch.float16:
+                mask_value = torch.finfo(torch.float32).min
+            else:
+                mask_value = 1
+            prefill_mask = torch.triu(
+                torch.ones(512, 512, device=self.device, dtype=dtype), 1)
+            self.mla_mask = torch.where(prefill_mask == 1, mask_value,
+                                        0).to(dtype)
+        return self.mla_mask
+
+    def get_pcp_mla_mask(self, dtype: torch.dtype):
+        if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
+            self.pcp_mla_mask = torch.triu(
+                torch.ones(512, 512, device=self.device, dtype=dtype), 1)
+        return self.pcp_mla_mask
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index cd2111fa..a6a6447c 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -202,7 +202,6 @@ class AscendMLAMetadataBuilder:
     understand this class
     """
 
-    # _attn_mask_builder = None
     def __init__(self,
                  kv_cache_spec,
                  layer_names,
@@ -862,7 +861,6 @@ class AscendMLAImpl(MLAAttentionImpl):
         vllm_config = get_current_vllm_config()
 
         self.ring_mla_mask_size = 512
-        self.prefill_mask = None
 
         self.speculative_config = vllm_config.speculative_config
         self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
@@ -1167,10 +1165,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
 
-        if self.pcp_size > 1:
-            mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
-        else:
-            mask = self.prefill_mask
+        mask = attn_metadata.attn_mask
         torch_npu.atb.npu_ring_mla(
             q_nope=q_nope,
             q_rope=q_pe,
@@ -1214,24 +1209,12 @@ class AscendMLAImpl(MLAAttentionImpl):
                                  num_tokens,
                                  dtype=torch.float32,
                                  device=q_nope.device)
-        if self.prefill_mask is None:
-            if q_nope.dtype == torch.float16:
-                mask_value = torch.finfo(torch.float32).min
-            else:
-                mask_value = 1
-            prefill_mask = torch.triu(
-                torch.ones(self.ring_mla_mask_size,
-                           self.ring_mla_mask_size,
-                           device=q_nope.device,
-                           dtype=q_nope.dtype), 1)
-            self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
-                                            0).to(q_nope.dtype)
         torch_npu.atb.npu_ring_mla(q_nope=q_nope,
                                    q_rope=q_pe,
                                    k_nope=k_nope,
                                    k_rope=k_pe,
                                    value=value,
-                                   mask=self.prefill_mask,
+                                   mask=attn_metadata.attn_mask,
                                    seqlen=attn_metadata.prefill.query_lens,
                                    head_num=self.num_heads,
                                    kv_head_num=self.num_heads,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index e929dacc..a2f71de7 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -88,8 +88,6 @@ class AscendCommonAttentionMetadata:
 
     attn_mask: torch.Tensor = None
 
-    fia_attn_mask: torch.Tensor = None
-
     spec_attn_mask: torch.Tensor = None
 
     attn_state: Any = None
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index af7d3689..27a7f717 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -77,9 +77,7 @@ class EagleProposer(Proposer):
                                              1,
                                              device=device,
                                              dtype=torch.int32)
-        attn_mask_len = self.vllm_config.model_config.max_model_len
-        self.attn_mask_builder = AttentionMaskBuilder(
-            attn_mask_len, self.vllm_config.model_config.dtype, device=device)
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
     def load_model(self, model: nn.Module) -> None:
         target_attn_layer_names = set(
@@ -570,9 +568,7 @@ class EagleProposer(Proposer):
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
             self.hidden_states[:batch_size] = hidden_states
-            attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
-                attn_metadata.seq_lens, positions_cpu,
-                self.vllm_config.model_config.dtype, self.device)
+            attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
             attn_metadata.attn_mask = attn_mask
             attn_metadata.block_tables = block_table.to(device)
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 56cef32c..f6fded7d 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -378,12 +378,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             self.block_size,
             use_mla=self.model_config.use_mla,
             use_sparse=self.use_sparse)
-        if self.pcp_size > 1:
-            self.attn_mask_builder = None
-        else:
-            self.attn_mask_builder = AttentionMaskBuilder(
-                self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
         self._set_up_drafter()
 
@@ -651,10 +646,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             spec_token_num = self.speculative_config.num_speculative_tokens
             assert spec_token_num > 0
             self.decode_token_per_req = 1 + spec_token_num
-            self.spec_attn_mask = torch.triu(torch.ones(2048,
-                                                        2048,
-                                                        dtype=torch.bool),
-                                             diagonal=1).to(self.device)
+            self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            )
             if get_pp_group().is_last_rank:
                 self.drafter = self._get_drafter()
                 self.rejection_sampler = AscendRejectionSampler(self.sampler)
@@ -1033,21 +1026,19 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
         return tuple(tasks)
 
-    def _make_attention_mask(self, seq_lens, position,
-                             attn_state) -> torch.Tensor:
-        # pcp situation.
-        if self.pcp_size > 1:
-            return None
+    def _make_attention_mask(self, attn_state) -> torch.Tensor:
         if self.attn_mask_builder is None:
            raise ValueError("Attn mask builder is None")
-        # dcp situation.
-        if self.dcp_size > 1:
-            return self.attn_mask_builder.get_splitfuse_attn_mask()
-        if self.vllm_config.model_config.use_mla:
-            return None
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
-            return self.attn_mask_builder.get_pooling_mask(self.device)
+            return self.attn_mask_builder.get_pooling_mask()
+
+        if self.vllm_config.model_config.use_mla:
+            if self.pcp_size > 1:
+                return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
+            # mla prefill
+            if attn_state != AscendAttentionState.DecodeOnly:
+                return self.attn_mask_builder.get_mla_mask(self.dtype)
         return self.attn_mask_builder.get_splitfuse_attn_mask()
 
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1668,16 +1660,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self.positions[:num_input_tokens].copy_(
             self.positions_cpu[:num_input_tokens], non_blocking=True)
 
-        # Make Attention metadata
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
-
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
-                                                   position=positions_cpu,
-                                                   attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(attn_state)
         self.attn_state = attn_state  # type: ignore
         self.with_prefill = with_prefill
 
@@ -2840,12 +2825,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
 
         self.query_start_loc_cpu[1:num_reqs + 1] = torch.Tensor(cu_num_tokens)
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
-
-        assigned_mask_dim = 2048
-        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                               assigned_mask_dim),
-                                    diagonal=1).to(torch.int8).to(
-                                        self.device)
+        self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
 
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -4499,18 +4479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         tail_attn_nomask_seqlens = torch.tensor(
             [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32)
 
-        if self.vllm_config.model_config.use_mla:
-            pcp_prefill_mask = torch.triu(
-                torch.ones(512,
-                           512,
-                           device=self.device,
-                           dtype=self.dtype), 1)
-        else:
-            pcp_prefill_mask = torch.triu(
-                torch.full((2048, 2048),
-                           True,
-                           device=self.device,
-                           dtype=torch.bool), 1)
+        pcp_prefill_mask = self.attn_mask
 
         self.extra_long_seq_kwargs = {
             'attn_mask_seqlens': attn_mask_seqlens,
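Note for reviewers: below is a minimal usage sketch of the refactored `AttentionMaskBuilder` API introduced by this diff. It is not part of the change itself; it mirrors the unit tests above and uses a CPU device in place of an NPU.

```python
import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

# One builder per device; every mask is built lazily and cached on first use.
builder = AttentionMaskBuilder(torch.device("cpu"))

# Causal mask: the cache is regenerated only when a larger max_seq_len
# (or a different dtype) is requested.
mask = builder.get_attn_mask(max_seq_len=512, dtype=torch.float16)
assert mask.shape == (512, 512)
assert mask[0][-1] == float("-inf")

# A larger request grows the cached mask; smaller requests slice into it.
bigger = builder.get_attn_mask(max_seq_len=2048, dtype=torch.float16)
assert bigger.shape == (2048, 2048)

# Fixed 2048x2048 int8 mask for chunked prefill / splitfuse, built once.
splitfuse = builder.get_splitfuse_attn_mask()
assert splitfuse.shape == (2048, 2048) and splitfuse.dtype == torch.int8

# 512x512 masks consumed by the ring MLA path, keyed by dtype.
mla_mask = builder.get_mla_mask(torch.bfloat16)
pcp_mla_mask = builder.get_pcp_mla_mask(torch.bfloat16)
assert mla_mask.shape == pcp_mla_mask.shape == (512, 512)
```

Because each getter memoizes its tensor, call sites such as `_make_attention_mask` and the eagle proposer can fetch a mask on every step without reallocating it or re-copying it to the device.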