### What this PR does / why we need it?
Support prefill cache mode by using the FIA (fused infer attention) op in full-graph mode.
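As a rough, op-level illustration of what the FIA path does (a minimal sketch under assumed shapes and parameters, not the exact integration in this PR):

```python
# Illustrative sketch only: a full prefill pass through torch_npu's
# fused-infer-attention (FIA) kernel, using the same compressed 2048x2048
# causal mask that AttentionMaskBuilder builds on CANN 8.3. The shapes,
# num_heads and sparse_mode values below are assumptions for the example.
import torch
import torch_npu  # requires an Ascend NPU environment

# Compressed causal mask: 1 above the diagonal marks positions to mask out.
atten_mask = torch.triu(torch.ones(2048, 2048, dtype=torch.int8), diagonal=1).npu()

# BSH layout: (batch, seq_len, num_heads * head_dim); illustrative sizes.
q = torch.randn(1, 1024, 4096, dtype=torch.bfloat16).npu()
k = torch.randn(1, 1024, 4096, dtype=torch.bfloat16).npu()
v = torch.randn(1, 1024, 4096, dtype=torch.bfloat16).npu()

out, _ = torch_npu.npu_fused_infer_attention_score(
    q, k, v,
    num_heads=32,
    input_layout="BSH",
    scale=1.0 / 128 ** 0.5,  # 1/sqrt(head_dim) for an assumed head_dim of 128
    atten_mask=atten_mask,
    sparse_mode=3,           # assumed: causal mode with the compressed mask
)
```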
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: 17c540a993
Before (origin):

```
============ Serving Benchmark Result ============
Successful requests: 30
Maximum request concurrency: 256
Request rate configured (RPS): 0.70
Benchmark duration (s): 131.63
Total input tokens: 61363
Total generated tokens: 61440
Request throughput (req/s): 0.23
Output token throughput (tok/s): 466.77
Peak output token throughput (tok/s): 750.00
Peak concurrent requests: 30.00
Total Token throughput (tok/s): 932.95
---------------Time to First Token----------------
Mean TTFT (ms): 125.17
Median TTFT (ms): 121.51
P50 TTFT (ms): 121.51
P90 TTFT (ms): 140.91
P99 TTFT (ms): 182.36
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 43.85
Median TPOT (ms): 43.84
P50 TPOT (ms): 43.84
P90 TPOT (ms): 44.28
P99 TPOT (ms): 44.32
---------------Inter-token Latency----------------
Mean ITL (ms): 43.85
Median ITL (ms): 42.63
P50 ITL (ms): 42.63
P90 ITL (ms): 48.74
P99 ITL (ms): 59.62
==================================================
```
After (this PR):

```
============ Serving Benchmark Result ============
Successful requests: 30
Maximum request concurrency: 256
Request rate configured (RPS): 0.70
Benchmark duration (s): 130.10
Total input tokens: 61363
Total generated tokens: 61440
Request throughput (req/s): 0.23
Output token throughput (tok/s): 472.26
Peak output token throughput (tok/s): 750.00
Peak concurrent requests: 30.00
Total Token throughput (tok/s): 943.94
---------------Time to First Token----------------
Mean TTFT (ms): 123.69
Median TTFT (ms): 122.51
P50 TTFT (ms): 122.51
P90 TTFT (ms): 143.69
P99 TTFT (ms): 165.00
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 43.07
Median TPOT (ms): 43.13
P50 TPOT (ms): 43.13
P90 TPOT (ms): 43.50
P99 TPOT (ms): 43.57
---------------Inter-token Latency----------------
Mean ITL (ms): 43.07
Median ITL (ms): 41.81
P50 ITL (ms): 41.81
P90 ITL (ms): 48.11
P99 ITL (ms): 62.13
==================================================
```
Signed-off-by: shiyuan680 <917935075@qq.com>
The `AttentionMaskBuilder` module (116 lines, 4.8 KiB, Python):
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


def _generate_attn_mask(max_seq_len, dtype):
    # Construct lower triangle matrix.
    mask_flag = torch.ones((max_seq_len, max_seq_len),
                           dtype=torch.bool).tril_()
    # Create upper triangle matrix used to mark mask positions.
    mask_flag = ~mask_flag
    # Currently for fp16 dtype, the mask value should be set to -inf.
    # TODO: Eliminate this part in the future.
    mask_value = float('-inf') if dtype == torch.float16 else 1
    attn_mask = torch.zeros(size=(max_seq_len, max_seq_len), dtype=dtype) \
        .masked_fill_(mask_flag, mask_value)
    return attn_mask


class AttentionMaskBuilder:

    def __init__(
        self,
        max_seq_len: int,
        dtype: torch.dtype,
        device: torch.device = None,
    ):
        # NOTE: The device argument specifies the target NPU
        # to be used for the newly added FIA operator.
        # Only pass this parameter when using the new FIA operator.

        attn_mask = _generate_attn_mask(max_seq_len, dtype)

        self._seq_len_cached = attn_mask.shape[0]
        self.attn_mask_cache = attn_mask
        self.device = device
        self.pooling_mask = None
        if torch.version.cann.startswith("8.3"):
            assigned_mask_dim = 2048
            self.chunked_prefill_attn_mask = torch.triu(
                torch.ones(assigned_mask_dim, assigned_mask_dim),
                diagonal=1).to(torch.int8).to(device)

    @staticmethod
    def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
        if dtype == torch.float16:
            mask_scale_factor = 1
        elif dtype == torch.bfloat16:
            mask_scale_factor = -10000
        else:
            raise ValueError(
                "The current operation now only supports data types: torch.float16 and "
                "torch.bfloat16. Please ensure the input is of one of these types."
            )
        return mask_scale_factor

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                      device: torch.device):
        if max_seq_len == 2048 and torch.version.cann.startswith("8.3"):
            return self.chunked_prefill_attn_mask.to(torch.bool)
        self._update_attn_cache(max_seq_len, dtype)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
        ).to(device, non_blocking=True)

    def get_pooling_mask(self, device):
        if self.pooling_mask is None:
            # The compressed attention mask for npu_fusion_attention sparse mode 4.
            self.pooling_mask = torch.triu(torch.ones(
                2048, 2048), diagonal=1).to(torch.bool).to(device,
                                                           non_blocking=True)
        return self.pooling_mask

    def get_splitfuse_attn_mask(
        self,
        seq_lens: torch.Tensor = None,
        position: torch.Tensor = None,
        dtype: torch.dtype = None,
        device: torch.device = None,
    ) -> torch.Tensor:
        if torch.version.cann.startswith("8.3"):
            return self.chunked_prefill_attn_mask
        else:
            if dtype not in [torch.float16, torch.bfloat16]:
                raise ValueError(
                    "splitfuse_attn_mask now only supports bf16 and fp16")
            max_seq_len = max(seq_lens, default=0)
            self._update_attn_cache(max_seq_len, dtype)
            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
            # is not the same. Fix this in the future when kernel is ready.
            mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
                dtype)
            attn_mask = torch.index_select(self.attn_mask_cache,
                                           dim=0,
                                           index=position)[:, :max_seq_len]
            attn_mask *= mask_scale_factor
            return attn_mask.contiguous().to(device, non_blocking=True)

    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
        if self.attn_mask_cache.dtype != dtype:
            self.attn_mask_cache = self.attn_mask_cache.to(dtype)
```
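For reference, a minimal usage sketch of the builder (illustrative only; it assumes a torch_npu environment, since the constructor reads `torch.version.cann`, and the actual call sites live in the attention backends):

```python
import torch

# Build masks once on the host; slices are moved to the target device on use.
builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.bfloat16)

# Full causal mask for a prefill of up to 512 tokens (sliced from the cache,
# which is regenerated lazily if a longer sequence shows up).
prefill_mask = builder.get_attn_mask(512, torch.bfloat16, torch.device("npu"))

# Chunked-prefill (splitfuse) mask: on CANN 8.3 this returns the shared
# compressed 2048x2048 int8 mask; otherwise one row is gathered per query
# position and scaled by the dtype-dependent factor.
seq_lens = torch.tensor([5, 3])
positions = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])
splitfuse_mask = builder.get_splitfuse_attn_mask(seq_lens, positions,
                                                 torch.bfloat16,
                                                 torch.device("npu"))
```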