2025-07-09 09:12:03 +08:00
|
|
|
#
|
|
|
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import torch
|
2026-01-07 17:09:52 +08:00
|
|
|
from vllm.distributed import get_pcp_group
|
|
|
|
|
|
|
|
|
|
from vllm_ascend.platform import ModelConfig
|
|
|
|
|
from vllm_ascend.utils import singleton
|
2025-07-09 09:12:03 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_attn_mask(max_seq_len, dtype):
|
|
|
|
|
# Construct lower triangle matrix.
|
2026-01-19 08:59:46 +08:00
|
|
|
mask_flag = torch.ones((max_seq_len, max_seq_len), dtype=torch.bool).tril_()
|
2025-07-09 09:12:03 +08:00
|
|
|
# Create upper triangle matrix used to mark mask positions.
|
|
|
|
|
mask_flag = ~mask_flag
|
|
|
|
|
# Currently for fp16 dtype, the mask value should be set to -inf.
|
|
|
|
|
# TODO: Eliminate this part in the future.
|
2026-01-19 08:59:46 +08:00
|
|
|
mask_value = float("-inf") if dtype == torch.float16 else 1
|
|
|
|
|
attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype).masked_fill_(mask_flag, mask_value)
|
2025-07-09 09:12:03 +08:00
|
|
|
return attn_mask
|
|
|
|
|
|
|
|
|
|
|
2026-01-07 17:09:52 +08:00
|
|
|
@singleton
class AttentionMaskBuilder:
    """Lazily build and cache the attention masks used by the attention backends.

    Each ``get_*`` method builds its mask on first use, stores it on the
    instance, and reuses it on later calls until the requested parameters
    (size, dtype or sliding window) require a rebuild.
    """

    def __init__(self, device: torch.device):
        # Causal mask from ``_generate_attn_mask``. It is built without an
        # explicit device and copied to ``self.device`` every time it is requested.
        self.attn_mask_cache = None
        # Side length of ``attn_mask_cache`` (0 while nothing is cached).
        self._seq_len_cached = 0
        # Target device for the masks handed to callers.
        self.device = device
        self.mla_mask = None
        self.chunked_prefill_attn_mask = None
        self.pcp_mla_mask = None
        self.swa_mask = None
        # Sliding-window size that ``swa_mask`` was built for.
        self._swa_window = None

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype) -> torch.Tensor:
        """Return a ``(max_seq_len, max_seq_len)`` causal mask on ``self.device``.

        The cache is rebuilt when it is empty, too small, or has a different
        dtype. A dtype change triggers a rebuild rather than a cast because
        the masked value depends on the dtype (``-inf`` for float16, ``1``
        otherwise). Casting a cached mask would carry the wrong value over,
        e.g. ``1.0`` instead of ``-inf`` when switching to float16.

        Args:
            max_seq_len: Side length of the returned mask.
            dtype: Element dtype of the returned mask.

        Returns:
            A contiguous slice of the cached mask, copied to ``self.device``.
        """
        if (
            self.attn_mask_cache is None
            or max_seq_len > self._seq_len_cached
            or self.attn_mask_cache.dtype != dtype
        ):
            # Never shrink the cache: keep the largest size requested so far.
            seq_len = max(max_seq_len, self._seq_len_cached)
            self.attn_mask_cache = _generate_attn_mask(seq_len, dtype)
            self._seq_len_cached = seq_len
        assert self.attn_mask_cache is not None, "Something is wrong in generate_attn_mask."
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous().to(self.device, non_blocking=True)

    def get_splitfuse_attn_mask(self) -> torch.Tensor:
        """Return the fixed 2048x2048 int8 causal mask for chunked prefill.

        Entries strictly above the diagonal are 1 (masked); all others are 0.
        The mask is built once and placed on ``self.device``.
        """
        if self.chunked_prefill_attn_mask is None:
            self.chunked_prefill_attn_mask = (
                torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.int8).to(self.device)
            )
        return self.chunked_prefill_attn_mask

    def get_mla_mask(self, dtype: torch.dtype) -> torch.Tensor:
        """Return a 512x512 causal mask in ``dtype`` on ``self.device``.

        Entries strictly above the diagonal hold the masked value; all others
        are 0. For float16 the masked value is ``torch.finfo(torch.float32).min``.
        That value is outside float16's range and saturates to ``-inf`` once
        cast. For any other dtype the masked value is 1. The mask is rebuilt
        only when ``dtype`` changes.
        """
        if self.mla_mask is None or self.mla_mask.dtype != dtype:
            if dtype == torch.float16:
                mask_value = torch.finfo(torch.float32).min
            else:
                mask_value = 1
            prefill_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
            self.mla_mask = torch.where(prefill_mask == 1, mask_value, 0).to(dtype)
        return self.mla_mask

    def get_pcp_mla_mask(self, dtype: torch.dtype):
        """Return a 512x512 mask in ``dtype`` with 1 strictly above the diagonal.

        All other entries are 0. The mask is rebuilt only when ``dtype`` changes.
        """
        if self.pcp_mla_mask is None or self.pcp_mla_mask.dtype != dtype:
            self.pcp_mla_mask = torch.triu(torch.ones(512, 512, device=self.device, dtype=dtype), 1)
        return self.pcp_mla_mask

    def get_swa_mask(self, dtype: torch.dtype, sliding_window):
        """Return the 2048x2048 boolean sliding-window attention mask.

        ``True`` marks masked positions. These are future tokens (strict upper
        triangle) and tokens that lie ``sliding_window`` or more positions in
        the past.

        The mask is always ``torch.bool``, so the cache is keyed on
        ``sliding_window`` rather than on ``dtype``. Comparing the bool mask's
        dtype against ``dtype`` would force a rebuild on every call for any
        non-bool dtype, and would miss window changes when ``dtype`` is bool.

        Args:
            dtype: Not used to build the mask; kept for interface compatibility.
            sliding_window: Window size. When ``None``, nothing is built and
                the previously cached mask (possibly ``None``) is returned.
        """
        if sliding_window is not None and (
            self.swa_mask is None or self._swa_window != sliding_window
        ):
            full = torch.ones(2048, 2048, dtype=torch.bool)
            future = torch.triu(full, diagonal=1)
            outside_window = torch.tril(full, -sliding_window)
            self.swa_mask = (future | outside_window).to(self.device)
            self._swa_window = sliding_window
        return self.swa_mask

    def get_attention_mask(self, model_config: ModelConfig):
        """Return the attention mask appropriate for ``model_config``.

        Pooling runners get a 2048x2048 boolean causal mask from
        ``get_attn_mask``. Every other runner type gets the int8 split-fuse
        mask from ``get_splitfuse_attn_mask``.
        """
        if model_config.runner_type == "pooling":
            return self.get_attn_mask(2048, torch.bool)

        return self.get_splitfuse_attn_mask()

    def get_final_mla_mask(self, model_config: ModelConfig):
        """Return the MLA prefill mask in ``model_config.dtype``.

        When the PCP process group has more than one rank, the 0/1
        ``get_pcp_mla_mask`` is used. Otherwise the ``get_mla_mask`` variant
        is used.
        """
        if get_pcp_group().world_size > 1:
            return self.get_pcp_mla_mask(model_config.dtype)

        # Prefill stages use 512x512 mask with appropriate dtype
        return self.get_mla_mask(model_config.dtype)
|