[Refactor] 2/N Unify all mask generation methods and cache mask (#4779)
RFC: https://github.com/vllm-project/vllm-ascend/issues/4629
Reason:
There are various types of masks here, and some of them do not have a
caching mechanism. As a result, the masks need to be initialized for
each layer, leading to waste of video memory.
At the same time, we hope to standardize the management and usage of
masks.
So we have gathered all the masks into the AttentionMaskBuilder class.
Todo:
1. remove spec_attn_mask; @LICO1314
2. remove pcp_prefill_mask; @LICO1314
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Signed-off-by: daishixun <dsxsteven@sina.com>
Signed-off-by: lulina <lina.lulina@huawei.com>
Signed-off-by: zengran <zengran2@huawei.com>
Signed-off-by: shiro-zzzz <zhangdianhao@huawei.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: 李少鹏 <lishaopeng21@huawei.com>
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: dsxsteven <36877507+dsxsteven@users.noreply.github.com>
Co-authored-by: LuLina <lina.lulina@huawei.com>
Co-authored-by: zengzengran <zengran2@huawei.com>
Co-authored-by: shiro-zzzz <zhangdianhao@huawei.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: shaopeng-666 <lishaopeng21@huawei.com>
Co-authored-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: lhp-deep <liuhaopeng1@huawei.com>
Co-authored-by: Canlin Guo <canlinguosdu@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
This commit is contained in:
@@ -31,66 +31,54 @@ def _generate_attn_mask(max_seq_len, dtype):
|
||||
|
||||
class AttentionMaskBuilder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device = None,
|
||||
):
|
||||
# NOTE: The device argument specifies the target NPU
|
||||
# to be used for the newly added FIA operator.
|
||||
# Only pass this parameter when using the new FIA operator.
|
||||
|
||||
attn_mask = _generate_attn_mask(max_seq_len, dtype)
|
||||
|
||||
self._seq_len_cached = attn_mask.shape[0]
|
||||
self.attn_mask_cache = attn_mask
|
||||
def __init__(self, device: torch.device):
|
||||
self.attn_mask_cache = None
|
||||
self._seq_len_cached = 0
|
||||
self.device = device
|
||||
self.pooling_mask = None
|
||||
assigned_mask_dim = 2048
|
||||
self.chunked_prefill_attn_mask = torch.triu(
|
||||
torch.ones(assigned_mask_dim, assigned_mask_dim),
|
||||
diagonal=1).to(torch.int8).to(device)
|
||||
self.mla_mask = None
|
||||
self.chunked_prefill_attn_mask = None
|
||||
self.pcp_mla_mask = None
|
||||
|
||||
@staticmethod
|
||||
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
|
||||
if dtype == torch.float16:
|
||||
mask_scale_factor = 1
|
||||
elif dtype == torch.bfloat16:
|
||||
mask_scale_factor = -10000
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current operation now only supports data types: torch.float16 and "
|
||||
"torch.bfloat16. Please ensure the input is of one of these types."
|
||||
)
|
||||
return mask_scale_factor
|
||||
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
|
||||
device: torch.device):
|
||||
self._update_attn_cache(max_seq_len, dtype)
|
||||
def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype):
|
||||
if self.attn_mask_cache is None or max_seq_len > self._seq_len_cached:
|
||||
self.attn_mask_cache = _generate_attn_mask(max_seq_len, dtype)
|
||||
self._seq_len_cached = max_seq_len
|
||||
assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||
).to(device, non_blocking=True)
|
||||
).to(self.device, non_blocking=True)
|
||||
|
||||
def get_pooling_mask(self, device):
|
||||
def get_pooling_mask(self):
|
||||
if self.pooling_mask is None:
|
||||
# the compressed attention mask for npu_fusion_attention sparse mode 4
|
||||
self.pooling_mask = torch.triu(torch.ones(
|
||||
2048, 2048), diagonal=1).to(torch.bool).to(device,
|
||||
2048, 2048), diagonal=1).to(torch.bool).to(self.device,
|
||||
non_blocking=True)
|
||||
return self.pooling_mask
|
||||
|
||||
def get_splitfuse_attn_mask(
|
||||
self,
|
||||
seq_lens: torch.Tensor = None,
|
||||
position: torch.Tensor = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
) -> torch.Tensor:
|
||||
def get_splitfuse_attn_mask(self) -> torch.Tensor:
|
||||
if self.chunked_prefill_attn_mask is None:
|
||||
self.chunked_prefill_attn_mask = torch.triu(
|
||||
torch.ones(2048,
|
||||
2048), diagonal=1).to(torch.int8).to(self.device)
|
||||
return self.chunked_prefill_attn_mask
|
||||
|
||||
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
|
||||
if seqlen > self._seq_len_cached:
|
||||
self._seq_len_cached = seqlen
|
||||
self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
|
||||
if self.attn_mask_cache.dtype != dtype:
|
||||
self.attn_mask_cache = self.attn_mask_cache.to(dtype)
|
||||
def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
|
||||
if self.mla_mask is None or self.mla_mask.dtype != dtype:
|
||||
if dtype == torch.float16:
|
||||
mask_value = torch.finfo(torch.float32).min
|
||||
else:
|
||||
mask_value = 1
|
||||
prefill_mask = torch.triu(
|
||||
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
|
||||
self.mla_mask = torch.where(prefill_mask == 1, mask_value,
|
||||
0).to(dtype)
|
||||
return self.mla_mask
|
||||
|
||||
def get_pcp_mla_mask(self, dtype: torch.dtype):
|
||||
if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
|
||||
self.pcp_mla_mask = torch.triu(
|
||||
torch.ones(512, 512, device=self.device, dtype=dtype), 1)
|
||||
return self.pcp_mla_mask
|
||||
|
||||
Reference in New Issue
Block a user