From 392fd7239bb8ab12e0ad45c454f6941ead6a2a56 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 9 Jul 2025 09:12:03 +0800 Subject: [PATCH] [Misc] Add attention mask (#1673) Move attention mark from V0 to common place. - vLLM version: v0.9.2 - vLLM main: https://github.com/vllm-project/vllm/commit/b942c094e3ab905aeb16f4136353f378e17159e8 Signed-off-by: wangxiyuan --- tests/ut/attention/test_attention_mask.py | 107 ++++++++++++++++++++++ vllm_ascend/attention/attention.py | 105 +-------------------- vllm_ascend/attention/attention_mask.py | 103 +++++++++++++++++++++ vllm_ascend/worker/eagle_proposer_v1.py | 4 +- vllm_ascend/worker/model_runner_v1.py | 4 +- 5 files changed, 216 insertions(+), 107 deletions(-) create mode 100644 tests/ut/attention/test_attention_mask.py create mode 100644 vllm_ascend/attention/attention_mask.py diff --git a/tests/ut/attention/test_attention_mask.py b/tests/ut/attention/test_attention_mask.py new file mode 100644 index 0000000..200c2a3 --- /dev/null +++ b/tests/ut/attention/test_attention_mask.py @@ -0,0 +1,107 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder + + +class TestAttentionMaskBuilder(TestBase): + + def test_init_attention_mask_builder(self): + # generate attention_mask_builder with float16 + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attention_mask_builder.attn_mask_cache.dtype, + torch.float16) + self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (1024, 1024)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + + # generate attention_mask_builder with int8 + attention_mask_builder = AttentionMaskBuilder(max_seq_len=512, + dtype=torch.int8) + self.assertEqual(attention_mask_builder._seq_len_cached, 512) + self.assertEqual(attention_mask_builder.attn_mask_cache.dtype, + torch.int8) + self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (512, 512)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(1, dtype=torch.int8)) + + def test_get_attn_mask(self): + # if the len is less than max_seq_len, the attn_mask_cache will not be updated + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + attn_mask = attention_mask_builder.get_attn_mask( + max_seq_len=512, dtype=torch.float16, device=torch.device("cpu")) + self.assertEqual(attn_mask.shape, (512, 512)) + self.assertEqual(attn_mask[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (1024, 1024)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + + # if the len is greater than max_seq_len, the attn_mask_cache will be updated + attn_mask = attention_mask_builder.get_attn_mask( + max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu")) + self.assertEqual(attn_mask.shape, (2048, 2048)) + self.assertEqual(attn_mask[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + self.assertEqual(attention_mask_builder._seq_len_cached, 2048) + self.assertEqual(attention_mask_builder.attn_mask_cache.shape, + (2048, 2048)) + self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1], + torch.tensor(float("-inf"), dtype=torch.float16)) + + def test_get_splitfuse_attn_mask(self): + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.float16) + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[512], + query_lens=[512], + position=torch.tensor([0]), + dtype=torch.float16, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1, 512)) + self.assertEqual(attention_mask_builder._seq_len_cached, 1024) + + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[2048], + query_lens=[1024], + position=torch.tensor([0]), + dtype=torch.float16, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1024, 2048)) + + attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, + dtype=torch.int8) + attn_mask = attention_mask_builder.get_splitfuse_attn_mask( + seq_lens=[512], + query_lens=[512], + position=torch.tensor([0]), + dtype=torch.int8, + device=torch.device("cpu"), + ) + self.assertEqual(attn_mask.shape, (1, 512)) diff --git a/vllm_ascend/attention/attention.py b/vllm_ascend/attention/attention.py index 4b545a1..35cb624 100644 --- a/vllm_ascend/attention/attention.py +++ b/vllm_ascend/attention/attention.py @@ -35,6 +35,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.ops.cache import concat_and_cache_mla from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, enable_custom_op, is_310p, nd_to_nz_2d) @@ -44,108 +45,6 @@ from vllm_ascend.worker.model_runner import ( _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128] -def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None): - # Construct lower triangle matrix. - mask_flag = torch.tril( - torch.ones((max_seq_len, max_seq_len), - dtype=torch.bool)).view(max_seq_len, max_seq_len) - # Create upper triangle matrix used to mark mask positions. - mask_flag = ~mask_flag - # Currently for fp16 dtype, the mask value should be set to -inf. - # TODO: Eliminate this part in the future. - if mask_value is None: - if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), - mask_flag, mask_value).to(dtype) - return attn_mask - - -class AttentionMaskBuilder: - - def __init__(self, attn_mask: torch.Tensor): - self._seq_len_cached = attn_mask.shape[0] - self.attn_mask_cache = attn_mask - self.splitfuse_mask_value = -10000 - - @classmethod - def initialize_from_len(cls, - max_seq_len: int, - dtype: torch.dtype = torch.float16, - mask_value: Optional[int] = None): - return cls(generate_attn_mask(max_seq_len, dtype, mask_value)) - - def update_attn_cache(self, seqlen: int, dtype: torch.dtype, - device: torch.device): - if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype: - self._seq_len_cached = seqlen - self.attn_mask_cache = generate_attn_mask(seqlen, dtype) - if self.attn_mask_cache.device != device: - self.attn_mask_cache = self.attn_mask_cache.to(device) - - def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, - device: torch.device): - self.update_attn_cache(max_seq_len, dtype, device) - return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous() - - def get_decode_attn_mask( - self, - input_lengths: torch.tensor, - max_s: int, - dtype: torch.dtype, - device: torch.device, - ): - self.update_attn_cache(max_s, dtype, device) - return (self.attn_mask_cache.index_select( - 0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous()) - - def get_splitfuse_attn_mask( - self, - seq_lens, - query_lens, - position, - dtype, - device, - ) -> torch.Tensor: - max_seq_len = max(seq_lens, default=0) - if max_seq_len <= self._seq_len_cached: - self.update_attn_cache(max_seq_len, dtype, device) - # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation - # is not the same. Fix this in the future when kernel is ready. - if self.attn_mask_cache.numel( - ) > 1 and self.attn_mask_cache[0][1] > 0: - attn_mask = self.get_attn_mask( # type: ignore - max_seq_len, dtype, device) - attn_mask *= -10000 - else: - attn_mask = self.attn_mask_cache - return torch.index_select(attn_mask, dim=0, - index=position)[:, :max_seq_len] - total_q_len = sum(query_lens) - attn_mask = torch.zeros((total_q_len, max_seq_len), - dtype=dtype, - device="cpu") - - current_row = 0 - for i in range(len(query_lens)): - seq_len = seq_lens[i] - q_len = query_lens[i] - context_len = seq_len - q_len - - assert context_len >= 0 - attn_mask[current_row:current_row + q_len, - context_len:] = self.splitfuse_mask_value - right_tensor = attn_mask[current_row:current_row + q_len, - context_len:seq_len] - right_tensor.masked_fill_( - right_tensor.tril() == self.splitfuse_mask_value, 0) - current_row += q_len - - return attn_mask.to(device, non_blocking=True) - - class AscendAttentionBackend(AttentionBackend): @staticmethod @@ -524,7 +423,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): self.compress_mask = None self.chunk_mask = None if AscendMetadataBuilder._attn_mask_builder is None: - AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder( 128, self.input_builder.runner.model_config.dtype) def _add_seq_group( diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py new file mode 100644 index 0000000..66ab414 --- /dev/null +++ b/vllm_ascend/attention/attention_mask.py @@ -0,0 +1,103 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def _generate_attn_mask(max_seq_len, dtype): + # Construct lower triangle matrix. + mask_flag = torch.tril( + torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool)).view(max_seq_len, max_seq_len) + # Create upper triangle matrix used to mark mask positions. + mask_flag = ~mask_flag + # Currently for fp16 dtype, the mask value should be set to -inf. + # TODO: Eliminate this part in the future. + if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), + mask_flag, mask_value).to(dtype) + return attn_mask + + +class AttentionMaskBuilder: + + def __init__( + self, + max_seq_len: int, + dtype: torch.dtype, + ): + attn_mask = _generate_attn_mask(max_seq_len, dtype) + + self._seq_len_cached = attn_mask.shape[0] + self.attn_mask_cache = attn_mask + self.splitfuse_mask_value = -10000 + + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, + device: torch.device): + self._update_attn_cache(max_seq_len, dtype, device) + return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous() + + def get_splitfuse_attn_mask( + self, + seq_lens, + query_lens, + position, + dtype, + device, + ) -> torch.Tensor: + max_seq_len = max(seq_lens, default=0) + if max_seq_len <= self._seq_len_cached: + self._update_attn_cache(max_seq_len, dtype, device) + # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation + # is not the same. Fix this in the future when kernel is ready. + if self.attn_mask_cache.numel( + ) > 1 and self.attn_mask_cache[0][1] > 0: + attn_mask = self.get_attn_mask( # type: ignore + max_seq_len, dtype, device) + attn_mask *= -10000 + else: + attn_mask = self.attn_mask_cache + return torch.index_select(attn_mask, dim=0, + index=position)[:, :max_seq_len] + total_q_len = sum(query_lens) + attn_mask = torch.zeros((total_q_len, max_seq_len), + dtype=dtype, + device="cpu") + current_row = 0 + for i in range(len(query_lens)): + seq_len = seq_lens[i] + q_len = query_lens[i] + context_len = seq_len - q_len + + assert context_len >= 0 + attn_mask[current_row:current_row + q_len, + context_len:] = self.splitfuse_mask_value + right_tensor = attn_mask[current_row:current_row + q_len, + context_len:seq_len] + right_tensor.masked_fill_( + right_tensor.tril() == self.splitfuse_mask_value, 0) + current_row += q_len + + return attn_mask.to(device, non_blocking=True) + + def _update_attn_cache(self, seqlen: int, dtype: torch.dtype, + device: torch.device): + if seqlen > self._seq_len_cached: + self._seq_len_cached = seqlen + self.attn_mask_cache = _generate_attn_mask(seqlen, dtype) + if self.attn_mask_cache.device != device: + self.attn_mask_cache = self.attn_mask_cache.to(device) diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index e2d22b8..d2fda0b 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -74,8 +74,8 @@ class EagleProposer: mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) self.attn_mask_len = min(self.model_config.max_model_len, int(mask_len)) - self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len( - self.attn_mask_len, self.dtype) + self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len, + self.dtype) def _make_attention_mask( self, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 27266db..38e1e33 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -325,8 +325,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): # the size of the pre-constructed mask matrix based on requirements. mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000) attn_mask_len = min(self.model_config.max_model_len, int(mask_len)) - self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len( - attn_mask_len, self.dtype) + self.attn_mask_builder = AttentionMaskBuilder(attn_mask_len, + self.dtype) self.new_kv_cache_bytes = -1 self.torchair_compiled_model = None # type: ignore