[Misc] Add attention mask (#1673)
Move attention mark from V0 to common place.
- vLLM version: v0.9.2
- vLLM main:
b942c094e3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
107
tests/ut/attention/test_attention_mask.py
Normal file
107
tests/ut/attention/test_attention_mask.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
|
||||
|
||||
class TestAttentionMaskBuilder(TestBase):
|
||||
|
||||
def test_init_attention_mask_builder(self):
|
||||
# generate attention_mask_builder with float16
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
|
||||
torch.float16)
|
||||
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(1024, 1024))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
# generate attention_mask_builder with int8
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
|
||||
dtype=torch.int8)
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 512)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
|
||||
torch.int8)
|
||||
self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(512, 512))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(1, dtype=torch.int8))
|
||||
|
||||
def test_get_attn_mask(self):
|
||||
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
attn_mask = attention_mask_builder.get_attn_mask(
|
||||
max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
|
||||
self.assertEqual(attn_mask.shape, (512, 512))
|
||||
self.assertEqual(attn_mask[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(1024, 1024))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
# if the len is greater than max_seq_len, the attn_mask_cache will be updated
|
||||
attn_mask = attention_mask_builder.get_attn_mask(
|
||||
max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
|
||||
self.assertEqual(attn_mask.shape, (2048, 2048))
|
||||
self.assertEqual(attn_mask[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(2048, 2048))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
def test_get_splitfuse_attn_mask(self):
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens=[512],
|
||||
query_lens=[512],
|
||||
position=torch.tensor([0]),
|
||||
dtype=torch.float16,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.assertEqual(attn_mask.shape, (1, 512))
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
|
||||
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens=[2048],
|
||||
query_lens=[1024],
|
||||
position=torch.tensor([0]),
|
||||
dtype=torch.float16,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.assertEqual(attn_mask.shape, (1024, 2048))
|
||||
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.int8)
|
||||
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens=[512],
|
||||
query_lens=[512],
|
||||
position=torch.tensor([0]),
|
||||
dtype=torch.int8,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.assertEqual(attn_mask.shape, (1, 512))
|
||||
@@ -35,6 +35,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.ops.cache import concat_and_cache_mla
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
||||
enable_custom_op, is_310p, nd_to_nz_2d)
|
||||
@@ -44,108 +45,6 @@ from vllm_ascend.worker.model_runner import (
|
||||
_ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
|
||||
|
||||
|
||||
def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.tril(
|
||||
torch.ones((max_seq_len, max_seq_len),
|
||||
dtype=torch.bool)).view(max_seq_len, max_seq_len)
|
||||
# Create upper triangle matrix used to mark mask positions.
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
if mask_value is None:
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||
mask_flag, mask_value).to(dtype)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(self, attn_mask: torch.Tensor):
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.splitfuse_mask_value = -10000
|
||||
|
||||
@classmethod
|
||||
def initialize_from_len(cls,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
mask_value: Optional[int] = None):
|
||||
return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
|
||||
|
||||
def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.device != device:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(device)
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self.update_attn_cache(max_seq_len, dtype, device)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
||||
|
||||
def get_decode_attn_mask(
|
||||
self,
|
||||
input_lengths: torch.tensor,
|
||||
max_s: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.update_attn_cache(max_s, dtype, device)
|
||||
return (self.attn_mask_cache.index_select(
|
||||
0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens,
|
||||
query_lens,
|
||||
position,
|
||||
dtype,
|
||||
device,
|
||||
) -> torch.Tensor:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
if max_seq_len <= self._seq_len_cached:
|
||||
self.update_attn_cache(max_seq_len, dtype, device)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
if self.attn_mask_cache.numel(
|
||||
) > 1 and self.attn_mask_cache[0][1] > 0:
|
||||
attn_mask = self.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
attn_mask *= -10000
|
||||
else:
|
||||
attn_mask = self.attn_mask_cache
|
||||
return torch.index_select(attn_mask, dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
total_q_len = sum(query_lens)
|
||||
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
|
||||
current_row = 0
|
||||
for i in range(len(query_lens)):
|
||||
seq_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
context_len = seq_len - q_len
|
||||
|
||||
assert context_len >= 0
|
||||
attn_mask[current_row:current_row + q_len,
|
||||
context_len:] = self.splitfuse_mask_value
|
||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||
context_len:seq_len]
|
||||
right_tensor.masked_fill_(
|
||||
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||
current_row += q_len
|
||||
|
||||
return attn_mask.to(device, non_blocking=True)
|
||||
|
||||
|
||||
class AscendAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
@@ -524,7 +423,7 @@ class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]):
|
||||
self.compress_mask = None
|
||||
self.chunk_mask = None
|
||||
if AscendMetadataBuilder._attn_mask_builder is None:
|
||||
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
|
||||
128, self.input_builder.runner.model_config.dtype)
|
||||
|
||||
def _add_seq_group(
|
||||
|
||||
103
vllm_ascend/attention/attention_mask.py
Normal file
103
vllm_ascend/attention/attention_mask.py
Normal file
@@ -0,0 +1,103 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
|
||||
def _generate_attn_mask(max_seq_len, dtype):
|
||||
# Construct lower triangle matrix.
|
||||
mask_flag = torch.tril(
|
||||
torch.ones((max_seq_len, max_seq_len),
|
||||
dtype=torch.bool)).view(max_seq_len, max_seq_len)
|
||||
# Create upper triangle matrix used to mark mask positions.
|
||||
mask_flag = ~mask_flag
|
||||
# Currently for fp16 dtype, the mask value should be set to -inf.
|
||||
# TODO: Eliminate this part in the future.
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
|
||||
mask_flag, mask_value).to(dtype)
|
||||
return attn_mask
|
||||
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
self.splitfuse_mask_value = -10000
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype, device)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens,
|
||||
query_lens,
|
||||
position,
|
||||
dtype,
|
||||
device,
|
||||
) -> torch.Tensor:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
if max_seq_len <= self._seq_len_cached:
|
||||
self._update_attn_cache(max_seq_len, dtype, device)
|
||||
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
|
||||
# is not the same. Fix this in the future when kernel is ready.
|
||||
if self.attn_mask_cache.numel(
|
||||
) > 1 and self.attn_mask_cache[0][1] > 0:
|
||||
attn_mask = self.get_attn_mask( # type: ignore
|
||||
max_seq_len, dtype, device)
|
||||
attn_mask *= -10000
|
||||
else:
|
||||
attn_mask = self.attn_mask_cache
|
||||
return torch.index_select(attn_mask, dim=0,
|
||||
index=position)[:, :max_seq_len]
|
||||
total_q_len = sum(query_lens)
|
||||
attn_mask = torch.zeros((total_q_len, max_seq_len),
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
current_row = 0
|
||||
for i in range(len(query_lens)):
|
||||
seq_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
context_len = seq_len - q_len
|
||||
|
||||
assert context_len >= 0
|
||||
attn_mask[current_row:current_row + q_len,
|
||||
context_len:] = self.splitfuse_mask_value
|
||||
right_tensor = attn_mask[current_row:current_row + q_len,
|
||||
context_len:seq_len]
|
||||
right_tensor.masked_fill_(
|
||||
right_tensor.tril() == self.splitfuse_mask_value, 0)
|
||||
current_row += q_len
|
||||
|
||||
return attn_mask.to(device, non_blocking=True)
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.device != device:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(device)
|
||||
@@ -74,8 +74,8 @@ class EagleProposer:
|
||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||
self.attn_mask_len = min(self.model_config.max_model_len,
|
||||
int(mask_len))
|
||||
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
self.attn_mask_len, self.dtype)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len,
|
||||
self.dtype)
|
||||
|
||||
def _make_attention_mask(
|
||||
self,
|
||||
|
||||
@@ -325,8 +325,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# the size of the pre-constructed mask matrix based on requirements.
|
||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||
attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
|
||||
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||
attn_mask_len, self.dtype)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(attn_mask_len,
|
||||
self.dtype)
|
||||
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user