[Refactor] 2/N Unify all mask generation methods and cache mask (#4779)

RFC: https://github.com/vllm-project/vllm-ascend/issues/4629

Reason:

There are several kinds of attention masks in this codebase, and some of them have no caching mechanism. As a result, those masks have to be re-initialized for every layer, which wastes device memory.

At the same time, we want to standardize how masks are managed and used.

We have therefore gathered all mask generation into the AttentionMaskBuilder class.
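
For context, the core of the change is a build-once, serve-from-cache accessor per mask type. Below is a minimal sketch of that pattern, not the actual vllm-ascend implementation; the 2048 mask size and bool dtype are assumptions borrowed from the old inline code visible in the diff below:

from typing import Optional

import torch


class CachedMaskBuilderSketch:
    """Builds each attention mask once and serves the cached tensor."""

    def __init__(self, device: torch.device, mask_len: int = 2048):
        self.device = device
        self.mask_len = mask_len  # assumed default, taken from the old inline code
        self._splitfuse_mask: Optional[torch.Tensor] = None

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        # The first call materializes the upper-triangular causal mask on
        # the target device; later calls reuse the cached tensor, so it is
        # no longer re-allocated per layer.
        if self._splitfuse_mask is None:
            self._splitfuse_mask = torch.triu(
                torch.ones(self.mask_len, self.mask_len, dtype=torch.bool),
                diagonal=1).to(self.device)
        return self._splitfuse_mask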

Todo:
1. remove spec_attn_mask;  @LICO1314
2. remove pcp_prefill_mask; @LICO1314


- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
commit c331503677 (parent dee00d0de3)
Author: weijinqian0
Date: 2025-12-09 18:51:00 +08:00
Committed by: GitHub

6 changed files with 66 additions and 174 deletions


@@ -378,12 +378,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             self.block_size,
             use_mla=self.model_config.use_mla,
             use_sparse=self.use_sparse)
-        if self.pcp_size > 1:
-            self.attn_mask_builder = None
-        else:
-            self.attn_mask_builder = AttentionMaskBuilder(
-                self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
         self._set_up_drafter()
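
Note the constructor now takes only the device: the max_num_batched_tokens and dtype arguments are gone, and the later hunks show dtype being supplied per call instead (e.g. get_mla_mask(self.dtype)). A hypothetical usage under that reading, with the "npu" device name assuming torch_npu is installed:

builder = AttentionMaskBuilder(torch.device("npu"))  # device-only constructor
causal = builder.get_splitfuse_attn_mask()           # cached after first use
mla = builder.get_mla_mask(torch.float16)            # dtype chosen per call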
@@ -651,10 +646,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             spec_token_num = self.speculative_config.num_speculative_tokens
             assert spec_token_num > 0
             self.decode_token_per_req = 1 + spec_token_num
-            self.spec_attn_mask = torch.triu(torch.ones(2048,
-                                                        2048,
-                                                        dtype=torch.bool),
-                                             diagonal=1).to(self.device)
+            self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            )
         if get_pp_group().is_last_rank:
             self.drafter = self._get_drafter()
             self.rejection_sampler = AscendRejectionSampler(self.sampler)
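
The old inline construction and the new accessor should agree; a quick parity check, assuming get_splitfuse_attn_mask() keeps the old 2048x2048 boolean triu layout (the builder instance here is hypothetical):

old_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1)
new_mask = attn_mask_builder.get_splitfuse_attn_mask()
assert torch.equal(old_mask, new_mask.to("cpu").to(torch.bool))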
@@ -1033,21 +1026,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         return tuple(tasks)

-    def _make_attention_mask(self, seq_lens, position,
-                             attn_state) -> torch.Tensor:
+    def _make_attention_mask(self, attn_state) -> torch.Tensor:
+        # pcp situation.
         if self.pcp_size > 1:
             return None
-        if self.attn_mask_builder is None:
-            raise ValueError("Attn mask builder is None")
+        # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
-        if self.vllm_config.model_config.use_mla:
-            return None
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
-            return self.attn_mask_builder.get_pooling_mask(self.device)
+            return self.attn_mask_builder.get_pooling_mask()
+        if self.vllm_config.model_config.use_mla:
+            if self.pcp_size > 1:
+                return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
+            # mla prefill
+            if attn_state != AscendAttentionState.DecodeOnly:
+                return self.attn_mask_builder.get_mla_mask(self.dtype)
         return self.attn_mask_builder.get_splitfuse_attn_mask()

     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1668,16 +1660,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             self.positions[:num_input_tokens].copy_(
                 self.positions_cpu[:num_input_tokens], non_blocking=True)

         # Make Attention metadata
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
-                                                   position=positions_cpu,
-                                                   attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(attn_state)
         self.attn_state = attn_state  # type: ignore
         self.with_prefill = with_prefill
@@ -2840,12 +2825,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self.query_start_loc_cpu[1:num_reqs +
                                  1] = torch.Tensor(cu_num_tokens)
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
-        assigned_mask_dim = 2048
-        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                                assigned_mask_dim),
-                                    diagonal=1).to(torch.int8).to(
-                                        self.device)
+        self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
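
Rough arithmetic on what the removed inline code allocated on every call, using the shapes above: torch.ones(2048, 2048) defaults to float32 before the int8 cast, so each rebuild touched roughly 16 MiB of transient memory plus a 4 MiB resident mask:

elems = 2048 * 2048                  # 4,194,304 mask entries
fp32_temp_mib = elems * 4 / 2**20    # ~16.0 MiB float32 intermediate
int8_mib = elems * 1 / 2**20         # 4.0 MiB int8 mask on device
print(fp32_temp_mib, int8_mib)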
@@ -4499,18 +4479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         tail_attn_nomask_seqlens = torch.tensor(
             [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
             dtype=torch.int32)
-        if self.vllm_config.model_config.use_mla:
-            pcp_prefill_mask = torch.triu(
-                torch.ones(512,
-                           512,
-                           device=self.device,
-                           dtype=self.dtype), 1)
-        else:
-            pcp_prefill_mask = torch.triu(
-                torch.full((2048, 2048),
-                           True,
-                           device=self.device,
-                           dtype=torch.bool), 1)
+        pcp_prefill_mask = self.attn_mask
         self.extra_long_seq_kwargs = {
             'attn_mask_seqlens': attn_mask_seqlens,
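
After the refactor, the pcp prefill path simply aliases the runner's cached mask instead of allocating a fresh, dtype-dependent tensor on each call; in PyTorch a plain assignment shares storage, as this small illustration (with a hypothetical attn_mask tensor) shows:

pcp_prefill_mask = attn_mask  # alias, no new device memory
assert pcp_prefill_mask.data_ptr() == attn_mask.data_ptr()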