diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 079efff..b1da723 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -17,19 +17,15 @@ import torch def _generate_attn_mask(max_seq_len, dtype): # Construct lower triangle matrix. - mask_flag = torch.tril( - torch.ones((max_seq_len, max_seq_len), - dtype=torch.bool)).view(max_seq_len, max_seq_len) + mask_flag = torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool).tril_() # Create upper triangle matrix used to mark mask positions. mask_flag = ~mask_flag # Currently for fp16 dtype, the mask value should be set to -inf. # TODO: Eliminate this part in the future. - if dtype == torch.float16: - mask_value = torch.finfo(torch.float32).min - else: - mask_value = 1 - attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), - mask_flag, mask_value).to(dtype) + mask_value = float('-inf') if dtype == torch.float16 else 1 + attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \ + .masked_fill_(mask_flag, mask_value) return attn_mask