# xc-llm-ascend/vllm_ascend/attention/attention_mask.py
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from vllm.distributed import get_pcp_group

from vllm_ascend.platform import ModelConfig
from vllm_ascend.utils import singleton


def _generate_attn_mask(max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    # Construct a lower-triangular matrix of ones.
    mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
    # Invert it into a strict upper-triangular matrix marking the positions
    # to mask.
    mask_flag = ~mask_flag
    # Currently, for the fp16 dtype the mask value has to be -inf.
    # TODO: Eliminate this special case in the future.
    mask_value = float("-inf") if dtype == torch.float16 else 1
    attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
    return attn_mask
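

# Illustration (not part of the original module): for max_seq_len=4 and a
# non-fp16 dtype, _generate_attn_mask returns a strict upper-triangular fill
# of ones,
#   [[0., 1., 1., 1.],
#    [0., 0., 1., 1.],
#    [0., 0., 0., 1.],
#    [0., 0., 0., 0.]]
# with the ones replaced by -inf when dtype is torch.float16.
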
@singleton
class AttentionMaskBuilder:
    """Builds and caches all attention-mask variants used on Ascend NPU.

    Masks are cached on the (singleton) builder so they are created once per
    process rather than once per layer.
    """

    def __init__(self, device: torch.device):
        self.attn_mask_cache = None
        self._seq_len_cached = 0
        self.device = device
        self.mla_mask = None
        self.chunked_prefill_attn_mask = None
        self.pcp_mla_mask = None
        self.swa_mask = None

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
        # Grow the cached mask when a longer sequence is requested; shorter
        # requests reuse a slice of the existing cache. Regenerate (rather
        # than cast) on a dtype change, because fp16 masks are filled with
        # -inf while other dtypes use 1.
        if (self.attn_mask_cache is None
                or max_seq_len > self._seq_len_cached
                or self.attn_mask_cache.dtype != dtype):
            self._seq_len_cached = max(max_seq_len, self._seq_len_cached)
            self.attn_mask_cache = _generate_attn_mask(self._seq_len_cached, dtype)
        return (self.attn_mask_cache[:max_seq_len, :max_seq_len]
                .contiguous()
                .to(self.device, non_blocking=True))
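
    # Illustration (not part of the original module): successive calls grow
    # the cache monotonically and slice views out of it, e.g.
    #   builder.get_attn_mask(128, torch.bfloat16)  # builds a 128x128 cache
    #   builder.get_attn_mask(64, torch.bfloat16)   # returns a 64x64 slice of it
    #   builder.get_attn_mask(256, torch.bfloat16)  # regenerates at 256x256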

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        # Fixed-size (2048 x 2048) compressed causal mask consumed by the
        # npu_fused_infer_attention_score op in chunked-prefill (splitfuse)
        # cases; 1 (int8) marks a masked position.
        if self.chunked_prefill_attn_mask is None:
            self.chunked_prefill_attn_mask = (
                torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device)
            )
        return self.chunked_prefill_attn_mask

    def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
        if self.mla_mask is None or self.mla_mask.dtype != dtype:
            if dtype == torch.float16:
                # float32 min overflows to -inf when cast to fp16 below.
                mask_value = torch.finfo(torch.float32).min
            else:
                mask_value = 1
            prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
            self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
        return self.mla_mask

    def get_pcp_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
        if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
            self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
        return self.pcp_mla_mask

    def get_swa_mask(self, dtype: torch.dtype, sliding_window: Optional[int]) -> Optional[torch.Tensor]:
        if self.swa_mask is None or self.swa_mask.dtype != dtype:
            if sliding_window is not None:
                mask = torch.ones(2048, 2048, dtype=torch.bool)
                # Mask future tokens (strict upper triangle) and tokens that
                # fall outside the sliding window (lower triangle shifted down
                # by the window size); `+` acts as a logical OR on bool tensors.
                triu_mask = torch.triu(mask, diagonal=1).to(self.device)
                tril_mask = torch.tril(mask, -sliding_window).to(self.device)
                self.swa_mask = triu_mask + tril_mask
        return self.swa_mask
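
    # Illustration (not part of the original module): on a 5x5 toy grid with
    # sliding_window=2, each query may attend only to itself and the previous
    # token (True marks a masked position):
    #   [[F, T, T, T, T],
    #    [F, F, T, T, T],
    #    [T, F, F, T, T],
    #    [T, T, F, F, T],
    #    [T, T, T, F, F]]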

    def get_attention_mask(self, model_config: ModelConfig) -> torch.Tensor:
        if model_config.runner_type == "pooling":
            return self.get_attn_mask(2048, torch.bool)
        return self.get_splitfuse_attn_mask()

    def get_final_mla_mask(self, model_config: ModelConfig) -> torch.Tensor:
        # With prefill context parallelism (pcp) enabled, use the raw 0/1
        # triangular pcp mask.
        if get_pcp_group().world_size > 1:
            return self.get_pcp_mla_mask(model_config.dtype)
        # Prefill stages use a 512x512 mask with dtype-appropriate fill values.
        return self.get_mla_mask(model_config.dtype)
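

# Usage sketch (illustrative only, not part of the original module). It runs
# the builder on CPU for clarity; real callers pass the NPU device, and
# get_final_mla_mask additionally requires an initialized pcp process group.
# @singleton is assumed here to return one shared instance per process, so
# repeated construction reuses the caches built below.
if __name__ == "__main__":
    builder = AttentionMaskBuilder(torch.device("cpu"))
    causal = builder.get_attn_mask(16, torch.bfloat16)  # (16, 16), grows on demand
    chunked = builder.get_splitfuse_attn_mask()         # fixed (2048, 2048) int8
    mla = builder.get_mla_mask(torch.bfloat16)          # fixed (512, 512)
    swa = builder.get_swa_mask(torch.bool, sliding_window=8)
    print(causal.shape, chunked.shape, mla.shape, swa.shape)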