[Refactor] 2/N Unify all mask generation methods and cache mask (#4779)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
There are various types of masks here, and some of them do not have a
caching mechanism. As a result, the masks need to be initialized for
each layer, leading to waste of video memory.
At the same time, we hope to standardize the management and usage of
masks.
So we have gathered all the masks into the AttentionMaskBuilder class.
Todo:
1. remove spec_attn_mask; @LICO1314
2. remove pcp_prefill_mask; @LICO1314
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
This commit is contained in:
@@ -21,58 +21,23 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
|
||||
class TestAttentionMaskBuilder(TestBase):
|
||||
|
||||
def test_init_attention_mask_builder(self):
|
||||
# generate attention_mask_builder with float16
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
|
||||
torch.float16)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(1024, 1024))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
# generate attention_mask_builder with bfloat16
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
|
||||
dtype=torch.bfloat16)
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(2048, 2048))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(1, dtype=torch.bfloat16))
|
||||
|
||||
def test_get_mask_scale_factor(self):
|
||||
# supported data types
|
||||
self.assertEqual(
|
||||
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
|
||||
self.assertEqual(
|
||||
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
|
||||
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
|
||||
# Otherwise raise ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
|
||||
|
||||
def test_get_attn_mask(self):
|
||||
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
attn_mask = attention_mask_builder.get_attn_mask(
|
||||
max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
|
||||
attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
|
||||
attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=512,
|
||||
dtype=torch.float16)
|
||||
self.assertEqual(attn_mask.shape, (512, 512))
|
||||
self.assertEqual(attn_mask[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 512)
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
|
||||
(1024, 1024))
|
||||
(512, 512))
|
||||
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
# if the len is greater than max_seq_len, the attn_mask_cache will be updated
|
||||
attn_mask = attention_mask_builder.get_attn_mask(
|
||||
max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
|
||||
attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=2048,
|
||||
dtype=torch.float16)
|
||||
self.assertEqual(attn_mask.shape, (2048, 2048))
|
||||
self.assertEqual(attn_mask[0][-1],
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
@@ -83,13 +48,6 @@ class TestAttentionMaskBuilder(TestBase):
|
||||
torch.tensor(float("-inf"), dtype=torch.float16))
|
||||
|
||||
def test_get_splitfuse_attn_mask(self):
|
||||
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
|
||||
dtype=torch.float16)
|
||||
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens=torch.tensor([10, 20, 100]),
|
||||
position=torch.tensor([7, 8, 9, 18, 19, 99]),
|
||||
dtype=torch.float16,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
|
||||
attn_mask = attention_mask_builder.get_splitfuse_attn_mask()
|
||||
self.assertEqual(attn_mask.shape, (2048, 2048))
|
||||
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
|
||||
|
||||
Reference in New Issue
Block a user