[Refactor] 2/N Unify all mask generation methods and cache mask (#4779)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
There are several kinds of attention masks, and some of them have no
caching mechanism, so a new mask has to be initialized for each layer,
wasting device memory. We also want to standardize how masks are
managed and used. So all mask generation is now gathered into the
AttentionMaskBuilder class.
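For context, a minimal sketch of the caching pattern the builder now
provides. Only AttentionMaskBuilder, its device-only constructor, and
get_splitfuse_attn_mask come from this PR; the 2048 mask dimension is
taken from the old inline code, and the cache layout below is
illustrative, not the actual implementation:

    # Sketch only: the real class unifies several mask types; this shows
    # the build-once, reuse-everywhere idea for the splitfuse mask.
    from typing import Optional

    import torch

    class AttentionMaskBuilder:
        _MASK_DIM = 2048  # upper bound used by the old inline torch.triu code

        def __init__(self, device: torch.device):
            self.device = device
            self._splitfuse_mask: Optional[torch.Tensor] = None

        def get_splitfuse_attn_mask(self) -> torch.Tensor:
            # Allocate the upper-triangular causal mask once, keep it on
            # device, and hand the cached tensor to every caller, instead
            # of re-initializing a 2048x2048 mask at each call site.
            if self._splitfuse_mask is None:
                self._splitfuse_mask = torch.triu(
                    torch.ones(self._MASK_DIM, self._MASK_DIM,
                               dtype=torch.bool),
                    diagonal=1).to(self.device)
            return self._splitfuse_mask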
Todo:
1. remove spec_attn_mask; @LICO1314
2. remove pcp_prefill_mask; @LICO1314
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
@@ -378,12 +378,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             self.block_size,
             use_mla=self.model_config.use_mla,
             use_sparse=self.use_sparse)
-        if self.pcp_size > 1:
-            self.attn_mask_builder = None
-        else:
-            self.attn_mask_builder = AttentionMaskBuilder(
-                self.scheduler_config.max_num_batched_tokens, self.dtype,
-                self.device)
+        self.attn_mask_builder = AttentionMaskBuilder(self.device)
 
         self._set_up_drafter()
 
@@ -651,10 +646,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             spec_token_num = self.speculative_config.num_speculative_tokens
             assert spec_token_num > 0
             self.decode_token_per_req = 1 + spec_token_num
-            self.spec_attn_mask = torch.triu(torch.ones(2048,
-                                                        2048,
-                                                        dtype=torch.bool),
-                                             diagonal=1).to(self.device)
+            self.spec_attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask(
+            )
             if get_pp_group().is_last_rank:
                 self.drafter = self._get_drafter()
                 self.rejection_sampler = AscendRejectionSampler(self.sampler)
@@ -1033,21 +1026,20 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         return tuple(tasks)
 
-    def _make_attention_mask(self, seq_lens, position,
-                             attn_state) -> torch.Tensor:
+    def _make_attention_mask(self, attn_state) -> torch.Tensor:
+        # pcp situation.
+        if self.pcp_size > 1:
+            return None
         if self.attn_mask_builder is None:
             raise ValueError("Attn mask builder is None")
+        # dcp situation.
         if self.dcp_size > 1:
             return self.attn_mask_builder.get_splitfuse_attn_mask()
+        if self.vllm_config.model_config.use_mla:
+            return None
         # Pooling situation.
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
-            return self.attn_mask_builder.get_pooling_mask(self.device)
+            return self.attn_mask_builder.get_pooling_mask()
 
-        if self.vllm_config.model_config.use_mla:
-            if self.pcp_size > 1:
-                return self.attn_mask_builder.get_pcp_mla_mask(self.dtype)
-            # mla prefill
-            if attn_state != AscendAttentionState.DecodeOnly:
-                return self.attn_mask_builder.get_mla_mask(self.dtype)
         return self.attn_mask_builder.get_splitfuse_attn_mask()
 
     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
@@ -1668,16 +1660,9 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self.positions[:num_input_tokens].copy_(
             self.positions_cpu[:num_input_tokens], non_blocking=True)
 
-        # Make Attention metadata
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
-
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
-                                                   position=positions_cpu,
-                                                   attn_state=attn_state)
+        self.attn_mask = self._make_attention_mask(attn_state)
         self.attn_state = attn_state  # type: ignore
 
         self.with_prefill = with_prefill
@@ -2840,12 +2825,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         self.query_start_loc_cpu[1:num_reqs +
                                  1] = torch.Tensor(cu_num_tokens)
         self.query_lens = torch.from_numpy(num_scheduled_tokens)
-
-        assigned_mask_dim = 2048
-        self.attn_mask = torch.triu(torch.ones(assigned_mask_dim,
-                                                assigned_mask_dim),
-                                    diagonal=1).to(torch.int8).to(
-                                        self.device)
+        self.attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()
 
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
@@ -4499,18 +4479,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
         tail_attn_nomask_seqlens = torch.tensor(
             [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
             dtype=torch.int32)
-        if self.vllm_config.model_config.use_mla:
-            pcp_prefill_mask = torch.triu(
-                torch.ones(512,
-                           512,
-                           device=self.device,
-                           dtype=self.dtype), 1)
-        else:
-            pcp_prefill_mask = torch.triu(
-                torch.full((2048, 2048),
-                           True,
-                           device=self.device,
-                           dtype=torch.bool), 1)
+        pcp_prefill_mask = self.attn_mask
 
         self.extra_long_seq_kwargs = {
             'attn_mask_seqlens': attn_mask_seqlens,