### What this PR does / why we need it?
This PR optimizes the `_compute_slot_mappings_kernel` for Ascend NPUs to
improve performance. The key changes include:
- A new Triton kernel implementation (`_compute_slot_mappings_kernel`)
with NPU-specific optimizations, such as using `tl.gather` to handle
non-contiguous memory access and replacing modulo operations with equivalent multiply-and-subtract arithmetic.
- A new method `compute_slot_mappings` in `AscendBlockTables` to use
this new kernel.
- An end-to-end test to verify the correctness of the new kernel against
the reference GPU implementation.
The optimization is needed to avoid performance degradation from scalar
computation on Ascend devices.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main:
ed359c497a
---------
Signed-off-by: lhp-deep <liuhaopeng1@huawei.com>
162 lines
6.2 KiB
Python
162 lines
6.2 KiB
Python
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/block_table.py
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
|
from vllm.v1.worker.gpu.block_table import BlockTables, _load_ptr
|
|
|
|
|
|
class AscendBlockTables(BlockTables):
    """Block tables for Ascend NPUs.

    Differs from the upstream ``BlockTables`` in two ways:

    * ``slot_mappings`` is re-allocated as int32, because vllm-ascend's
      ``reshape_and_cache`` requires int32 slot mappings.
    * ``compute_slot_mappings`` launches an NPU-tuned Triton kernel that loads
      a contiguous slice of each request's block-table row and selects the
      needed entries with ``tl.gather``.
    """

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_size: int = 1,
        cp_rank: int = 0,
        cp_interleave: int = 1,
    ):
        super().__init__(
            block_sizes,
            max_num_reqs,
            max_num_batched_tokens,
            max_model_len,
            device,
            cp_size,
            cp_rank,
            cp_interleave,
        )
        # The parent allocates its own ``slot_mappings``. Drop that reference
        # before allocating the replacement so the old tensor can be freed
        # right away instead of coexisting with the new one.
        del self.slot_mappings
        # vllm-ascend's reshape_and_cache requires int32 slot mappings, so
        # re-allocate the buffer with that dtype.
        self.slot_mappings: torch.Tensor = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int32,
            device=self.device,
        )
        # Number of block-table entries the kernel loads per request before
        # gathering from them (its TOTAL_BLOCK_SIZE). It must exceed every
        # block index reachable by a position < max_model_len, i.e.
        # ceil(max_model_len / (block_size * cp_size)) for the smallest block
        # size. It is rounded up to a power of two because the kernel builds
        # it with tl.arange. The previous hard-coded 4096 was wrong for longer
        # tables and oversized for short ones.
        # NOTE(review): very long contexts with small block sizes make this
        # large; confirm the NPU kernel still compiles for such configs.
        max_num_blocks = max(
            (
                (max_model_len + block_size * cp_size - 1) // (block_size * cp_size)
                for block_size in block_sizes
            ),
            default=1,
        )
        self.total_block_size: int = triton.next_power_of_2(max(max_num_blocks, 1))

    def compute_slot_mappings(
        self,
        idx_mapping: torch.Tensor,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
        num_tokens_padded: int,
    ) -> torch.Tensor:
        """Fill ``self.slot_mappings`` for the current batch and return a view.

        Args:
            idx_mapping: per-request index of the block-table row to use
                ([num_reqs]).
            query_start_loc: token start offset of each request plus the
                total token count ([num_reqs + 1]).
            positions: position of every scheduled token ([num_tokens]).
            num_tokens_padded: number of columns in the returned view.

        Returns:
            ``self.slot_mappings[:, :num_tokens_padded]``, with one row per KV
            cache group. Columns from the batch's real token count up to
            ``max_num_batched_tokens`` are filled with ``PAD_SLOT_ID``.
        """
        num_reqs = idx_mapping.shape[0]
        num_groups = self.num_kv_cache_groups
        # One program per (KV cache group, request). The extra program at
        # index num_reqs of each group writes the padding tail.
        _compute_slot_mappings_kernel[(num_groups, num_reqs + 1)](
            self.max_num_batched_tokens,
            idx_mapping,
            query_start_loc,
            positions,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mappings,
            self.slot_mappings.stride(0),
            self.cp_rank,
            CP_SIZE=self.cp_size,
            CP_INTERLEAVE=self.cp_interleave,
            PAD_ID=PAD_SLOT_ID,
            TRITON_BLOCK_SIZE=1024,  # type: ignore
            TOTAL_BLOCK_SIZE=self.total_block_size,
        )
        return self.slot_mappings[:, :num_tokens_padded]
|
|
|
|
|
|
# NPU-friendly variant of vLLM's slot-mapping kernel.
#
# Grid: (num_kv_cache_groups, num_reqs + 1).
# - Program (g, r) with r < num_reqs writes the slot ids of request r's tokens
#   into row g of slot_mappings_ptr.
# - The last program along axis 1 fills columns
#   [query_start_loc[num_reqs], max_num_tokens) of row g with PAD_ID.
#
# TOTAL_BLOCK_SIZE is the width of the block-table slice loaded per request.
# It must be a power of two larger than every block index a position can map
# to; tl.gather cannot reach entries beyond it.
@triton.jit
def _compute_slot_mappings_kernel(
    max_num_tokens,
    idx_mapping,  # [num_reqs]
    query_start_loc,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    block_sizes,  # [num_kv_cache_groups]
    slot_mappings_ptr,  # [num_kv_cache_groups, max_num_tokens]
    slot_mappings_stride,
    cp_rank,
    CP_SIZE: tl.constexpr,
    CP_INTERLEAVE: tl.constexpr,
    PAD_ID: tl.constexpr,
    TRITON_BLOCK_SIZE: tl.constexpr,
    TOTAL_BLOCK_SIZE: tl.constexpr,
):
    # kv cache group id
    group_id = tl.program_id(0)
    batch_idx = tl.program_id(1)
    slot_mapping_ptr = slot_mappings_ptr + group_id * slot_mappings_stride

    if batch_idx == tl.num_programs(1) - 1:
        # Padding program: everything from the real token count up to
        # max_num_tokens is set to PAD_ID (presumably so stale slot ids from
        # an earlier step cannot leak into padded tokens).
        actual_num_tokens = tl.load(query_start_loc + batch_idx)
        for i in range(actual_num_tokens, max_num_tokens, TRITON_BLOCK_SIZE):
            offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset, PAD_ID, mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    block_size = tl.load(block_sizes + group_id)

    req_state_idx = tl.load(idx_mapping + batch_idx)
    start_idx = tl.load(query_start_loc + batch_idx)
    end_idx = tl.load(query_start_loc + batch_idx + 1)

    # Indexing the block table with per-token block indices is a
    # non-contiguous access, which degrades to scalar computation on NPU.
    # Instead, load a contiguous slice of this request's row once (it does
    # not depend on the token chunk) and select entries with tl.gather.
    # The mask keeps the load inside the row. This assumes rows are
    # contiguous, so the row stride equals the row width.
    block_cols = tl.arange(0, TOTAL_BLOCK_SIZE)
    block_table_row = tl.load(
        block_table_ptr + req_state_idx * block_table_stride + block_cols,
        mask=block_cols < block_table_stride,
        other=0,
    )
    # The gather source stays float32 as in the original NPU port
    # (presumably required by tl.gather on NPU; confirm). float32 is exact
    # for block numbers below 2**24.
    block_table_row = block_table_row.to(tl.float32)

    tokens_per_block = block_size * CP_SIZE
    for i in range(start_idx, end_idx, TRITON_BLOCK_SIZE):
        offset = i + tl.arange(0, TRITON_BLOCK_SIZE)
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)

        # Convert positions to int32 for NPU compatibility; without the
        # conversion the computation degrades to scalar.
        positions = positions.to(tl.int32)
        block_indices = positions // tokens_per_block
        # Same as positions % tokens_per_block for non-negative positions.
        # % on int32 degrades to scalar computation on NPU.
        block_offsets = positions - tokens_per_block * block_indices

        # Convert back to int32 before any arithmetic. Computing slot ids in
        # float32 would round them once they exceed 2**24.
        block_numbers = tl.gather(block_table_row, block_indices, 0).to(tl.int32)

        if CP_SIZE == 1:
            # Common case: Context parallelism is not used.
            slot_ids = block_numbers * block_size + block_offsets
        else:
            # Context parallelism is used. Offsets are split into chunks of
            # CP_INTERLEAVE tokens, dealt round-robin across CP ranks. The
            # modulo operations are again written as multiply/subtract
            # (exact for non-negative values).
            chunk_idx = block_offsets // CP_INTERLEAVE
            rounds = block_offsets // (CP_INTERLEAVE * CP_SIZE)
            # chunk_idx % CP_SIZE == cp_rank
            is_local = (chunk_idx - rounds * CP_SIZE) == cp_rank
            # block_offsets % CP_INTERLEAVE
            remainder = block_offsets - chunk_idx * CP_INTERLEAVE
            local_offsets = rounds * CP_INTERLEAVE + remainder
            slot_ids = block_numbers * block_size + local_offsets
            slot_ids = tl.where(is_local, slot_ids, PAD_ID)

        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
|