[Refactor] 2/N Unify all mask generation methods and cache mask (#4779)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

There are various types of masks here, and some of them do not have a
caching mechanism. As a result, the masks need to be initialized for
each layer, leading to waste of video memory.

At the same time, we hope to standardize the management and usage of
masks.

So we have gathered all the masks into the AttentionMaskBuilder class.

Todo:
1. remove spec_attn_mask;  @LICO1314
2. remove pcp_prefill_mask; @LICO1314


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
This commit is contained in:
weijinqian0
2025-12-09 18:51:00 +08:00
committed by GitHub
parent dee00d0de3
commit c331503677
6 changed files with 66 additions and 174 deletions

View File

@@ -21,58 +21,23 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
class TestAttentionMaskBuilder(TestBase):
def test_init_attention_mask_builder(self):
# generate attention_mask_builder with float16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.float16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# generate attention_mask_builder with bfloat16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(1, dtype=torch.bfloat16))
def test_get_mask_scale_factor(self):
# supported data types
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
# Otherwise raise ValueError
with self.assertRaises(ValueError):
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
def test_get_attn_mask(self):
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=512,
dtype=torch.float16)
self.assertEqual(attn_mask.shape, (512, 512))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder._seq_len_cached, 512)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
(512, 512))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# if the len is greater than max_seq_len, the attn_mask_cache will be updated
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
attn_mask = attention_mask_builder.get_attn_mask(max_seq_len=2048,
dtype=torch.float16)
self.assertEqual(attn_mask.shape, (2048, 2048))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
@@ -83,13 +48,6 @@ class TestAttentionMaskBuilder(TestBase):
torch.tensor(float("-inf"), dtype=torch.float16))
def test_get_splitfuse_attn_mask(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.float16,
device=torch.device("cpu"),
)
attention_mask_builder = AttentionMaskBuilder(torch.device("cpu"))
attn_mask = attention_mask_builder.get_splitfuse_attn_mask()
self.assertEqual(attn_mask.shape, (2048, 2048))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)

View File

@@ -31,66 +31,54 @@ def _generate_attn_mask(max_seq_len, dtype):
class AttentionMaskBuilder:
def __init__(
self,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device = None,
):
# NOTE: The device argument specifies the target NPU
# to be used for the newly added FIA operator.
# Only pass this parameter when using the new FIA operator.
attn_mask = _generate_attn_mask(max_seq_len, dtype)
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
def __init__(self, device: torch.device):
self.attn_mask_cache = None
self._seq_len_cached = 0
self.device = device
self.pooling_mask = None
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
self.mla_mask = None
self.chunked_prefill_attn_mask = None
self.pcp_mla_mask = None
@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
if dtype == torch.float16:
mask_scale_factor = 1
elif dtype == torch.bfloat16:
mask_scale_factor = -10000
else:
raise ValueError(
"The current operation now only supports data types: torch.float16 and "
"torch.bfloat16. Please ensure the input is of one of these types."
)
return mask_scale_factor
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
device: torch.device):
self._update_attn_cache(max_seq_len, dtype)
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype)
self._seq_len_cached = max_seq_len
assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(device, non_blocking=True)
).to(self.device, non_blocking=True)
def get_pooling_mask(self, device):
def get_pooling_mask(self):
if self.pooling_mask is None:
# the compressed attention mask for npu_fusion_attention sparse mode 4
self.pooling_mask = torch.triu(torch.ones(
2048, 2048), diagonal=1).to(torch.bool).to(device,
2048, 2048), diagonal=1).to(torch.bool).to(self.device,
non_blocking=True)
return self.pooling_mask
def get_splitfuse_attn_mask(
self,
seq_lens: torch.Tensor = None,
position: torch.Tensor = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
def get_splitfuse_attn_mask(self) -> torch.Tensor:
if self.chunked_prefill_attn_mask is None:
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(2048,
2048), diagonal=1).to(torch.int8).to(self.device)
return self.chunked_prefill_attn_mask
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:
self._seq_len_cached = seqlen
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
if self.attn_mask_cache.dtype != dtype:
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
if self.mla_mask is None or self.mla_mask.dtype != dtype:
if dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
prefill_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
self.mla_mask = torch.where(prefill_mask == 1, mask_value,
0).to(dtype)
return self.mla_mask
def get_pcp_mla_mask(self, dtype: torch.dtype):
if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
self.pcp_mla_mask = torch.triu(
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
return self.pcp_mla_mask

View File

@@ -202,7 +202,6 @@ class AscendMLAMetadataBuilder:
understand this class
"""
# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
@@ -862,7 +861,6 @@ class AscendMLAImpl(MLAAttentionImpl):
vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
self.prefill_mask = None
self.speculative_config = vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
@@ -1167,10 +1165,7 @@ class AscendMLAImpl(MLAAttentionImpl):
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
if self.pcp_size > 1:
mask = attn_metadata.prefill.pcp_metadata.pcp_prefill_mask
else:
mask = self.prefill_mask
mask = attn_metadata.attn_mask
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
@@ -1214,24 +1209,12 @@ class AscendMLAImpl(MLAAttentionImpl):
num_tokens,
dtype=torch.float32,
device=q_nope.device)
if self.prefill_mask is None:
if q_nope.dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
prefill_mask = torch.triu(
torch.ones(self.ring_mla_mask_size,
self.ring_mla_mask_size,
device=q_nope.device,
dtype=q_nope.dtype), 1)
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
0).to(q_nope.dtype)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
mask=attn_metadata.attn_mask,
seqlen=attn_metadata.prefill.query_lens,
head_num=self.num_heads,
kv_head_num=self.num_heads,

View File

@@ -88,8 +88,6 @@ class AscendCommonAttentionMetadata:
attn_mask: torch.Tensor = None
fia_attn_mask: torch.Tensor = None
spec_attn_mask: torch.Tensor = None
attn_state: Any = None

View File

@@ -77,9 +77,7 @@ class EagleProposer(Proposer):
1,
device=device,
dtype=torch.int32)
attn_mask_len = self.vllm_config.model_config.max_model_len
self.attn_mask_builder = AttentionMaskBuilder(
attn_mask_len, self.vllm_config.model_config.dtype, device=device)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set(
@@ -570,9 +568,7 @@ class EagleProposer(Proposer):
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
attn_metadata.seq_lens, positions_cpu,
self.vllm_config.model_config.dtype, self.device)
attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device)

View File

@@ -378,12 +378,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.block_size,
use_mla=self.model_config.use_mla,
use_sparse=self.use_sparse)
if self.pcp_size > 1:
self.attn_mask_builder = None
else:
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
self.device)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self._set_up_drafter()
@@ -651,10 +646,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),
diagonal=1).to(self.device)
self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
)
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
self.rejection_sampler = AscendRejectionSampler(self.sampler)
@@ -1033,21 +1026,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
return tuple(tasks)
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
def _make_attention_mask(self, attn_state) -> torch.Tensor:
# pcp situation.
if self.pcp_size > 1:
return None
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# dcp situation.
if self.dcp_size > 1:
return self.attn_mask_builder.get_splitfuse_attn_mask()
if self.vllm_config.model_config.use_mla:
return None
# Pooling situation.
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
return self.attn_mask_builder.get_pooling_mask()
if self.vllm_config.model_config.use_mla:
if self.pcp_size > 1:
return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
# mla prefill
if attn_state != AscendAttentionState.DecodeOnly:
return self.attn_mask_builder.get_mla_mask(self.dtype)
return self.attn_mask_builder.get_splitfuse_attn_mask()
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1668,16 +1660,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
# Make Attention metadata
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.attn_mask = self._make_attention_mask(attn_state)
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
@@ -2840,12 +2825,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
self.query_start_loc_cpu[1:num_reqs +
1] = torch.Tensor(cu_num_tokens)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
assigned_mask_dim = 2048
self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
assigned_mask_dim),
diagonal=1).to(torch.int8).to(
self.device)
self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -4499,18 +4479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
tail_attn_nomask_seqlens = torch.tensor(
[chunk_seqlens, kv_with_q_tail_nomask_seqlens],
dtype=torch.int32)
if self.vllm_config.model_config.use_mla:
pcp_prefill_mask = torch.triu(
torch.ones(512,
512,
device=self.device,
dtype=self.dtype), 1)
else:
pcp_prefill_mask = torch.triu(
torch.full((2048, 2048),
True,
device=self.device,
dtype=torch.bool), 1)
pcp_prefill_mask = self.attn_mask
self.extra_long_seq_kwargs = {
'attn_mask_seqlens': attn_mask_seqlens,