RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
There are various types of attention masks, and some of them have no
caching mechanism. As a result, those masks are re-initialized for
each layer, which wastes device memory.
We also want to standardize how masks are managed and used, so all
masks have been gathered into the AttentionMaskBuilder class.
Todo:
1. remove spec_attn_mask; @LICO1314
2. remove pcp_prefill_mask; @LICO1314
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
54 lines
2.6 KiB
Python
54 lines
2.6 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import torch
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|
|
|
|
|
class TestAttentionMaskBuilder(TestBase):
    """Unit tests for AttentionMaskBuilder, exercised on a CPU device."""

    def _check_top_right_masked(self, mask, size):
        """Assert `mask` has shape (size, size) and mask[0][-1] is fp16 -inf.

        Row 0 / last column lies above the main diagonal (for size > 1), so
        the test expects that entry to carry the masking value.
        """
        expected_fill = torch.tensor(float("-inf"), dtype=torch.float16)
        self.assertEqual(mask.shape, (size, size))
        self.assertEqual(mask[0][-1], expected_fill)

    def test_get_attn_mask(self):
        builder = AttentionMaskBuilder(torch.device("cpu"))
        # First request (512): the test expects the cache to be built at
        # exactly the requested size.
        # Second, larger request (2048): the test expects the cache to be
        # regrown so that it matches the new, larger size.
        for seq_len in (512, 2048):
            mask = builder.get_attn_mask(max_seq_len=seq_len,
                                         dtype=torch.float16)
            self._check_top_right_masked(mask, seq_len)
            self.assertEqual(builder._seq_len_cached, seq_len)
            self._check_top_right_masked(builder.attn_mask_cache, seq_len)

    def test_get_splitfuse_attn_mask(self):
        builder = AttentionMaskBuilder(torch.device("cpu"))
        splitfuse_mask = builder.get_splitfuse_attn_mask()
        # Only the fixed 2048 x 2048 shape is verified here.
        self.assertEqual(splitfuse_mask.shape, (2048, 2048))
|