### What this PR does / why we need it?
To support prefix caching for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](https://github.com/vllm-project/vllm/pull/30877) and ports those
changes to the functions that vLLM-Ascend overrides.
Note:
1. `--mamba-cache-mode align` combined with PD disaggregation is not yet
supported in vLLM v0.17.0 (see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of the hybrid kv cache may result in a very
large block_size during scheduling. For example, when running Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 tokens will never be cached. Although this behavior
is consistent with vLLM, it still needs improvement in the future.
3. `--mamba-cache-mode align` requires copying mamba states during
forward steps. vLLM implements this with a triton kernel, but the original
version runs into bugs on Ascend hardware, so we patch in a new triton
kernel to avoid them.
### Does this PR introduce _any_ user-facing change?
To use the mamba prefix cache, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the mamba state copy function (see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
has no torch-native fallback, so users who cannot use triton may run into
trouble. A minimal usage sketch is shown below.
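For illustration only, here is a minimal offline-inference sketch. The model path, tensor-parallel size, and the `mamba_cache_mode` keyword (assumed to be the engine-arg form of the CLI flag above) are assumptions taken from the notes in this PR, not something the PR prescribes:

```python
# Hypothetical usage sketch: the flags from this PR mapped to engine args.
# The model and tp size follow the example in note 2 above; `mamba_cache_mode`
# is assumed to be the keyword form of --mamba-cache-mode.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3.5-35B-A3B",   # example model path (assumption)
    tensor_parallel_size=2,
    enable_prefix_caching=True,
    mamba_cache_mode="align",
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```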
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Angazenn <supperccell@163.com>
```python
import numpy as np
import torch

from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size

class BlockTable:
    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_sizes: list[int] | None = None,
        cp_kv_cache_interleave_size: int = 1,
        num_speculative_tokens: int = 0,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device
        self.physical_block_size = block_size

        try:
            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_world_size > 1 else 0
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
            self.pcp_world_size = 1
            self.pcp_rank = 0

        # If kernel_sizes is None or [0], use physical block size (no splitting)
        if kernel_sizes is None or kernel_sizes == [0]:
            self.block_size = block_size
            self.logical_block_size = block_size
            self.blocks_per_phys_block = 1
            self.use_hybrid_blocks = False
        else:
            # Find the first kernel size that divides physical_block_size evenly
            selected_kernel_size = None
            for kernel_size in kernel_sizes:
                if kernel_size > 0 and self.physical_block_size % kernel_size == 0:
                    selected_kernel_size = kernel_size
                    break

            if selected_kernel_size is None:
                raise ValueError(
                    f"None of the kernel sizes {kernel_sizes} can divide "
                    f"physical block size {self.physical_block_size} evenly"
                )

            self.block_size = selected_kernel_size
            self.logical_block_size = selected_kernel_size
            self.blocks_per_phys_block = self.physical_block_size // self.logical_block_size
            if self.blocks_per_phys_block > 1:
                self.use_hybrid_blocks = True
            else:
                self.use_hybrid_blocks = False

        if self.use_hybrid_blocks:
            logical_table_size = max_num_blocks_per_req * self.blocks_per_phys_block
        else:
            logical_table_size = max_num_blocks_per_req

        duplicate_size = 1
        if self.pcp_world_size * self.dcp_world_size > 1:
            duplicate_size += num_speculative_tokens
        self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, dtype=torch.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32
        )

        self.kernel_sizes = kernel_sizes
        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

    def append_row(
        self,
        block_ids,
        row_idx: int,
    ) -> None:
        if not block_ids:
            return
        block_ids = np.array(block_ids)
        if self.use_hybrid_blocks:
            block_ids = self._convert_physical_to_logical_blocks(block_ids)

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]

        self.block_table.np[row_idx, start : start + num_blocks] = block_ids
        self.num_blocks_per_row[row_idx] += num_blocks

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        num_blocks_src = self.num_blocks_per_row[src]
        num_blocks_tgt = self.num_blocks_per_row[tgt]
        self.num_blocks_per_row[src] = num_blocks_tgt
        self.num_blocks_per_row[tgt] = num_blocks_src

        self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]

    def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size`
        # here because M (max_model_len) is not necessarily divisible by
        # block_size.

        if self.dcp_world_size * self.pcp_world_size > 1:
            # Note(hc): The DCP implementation stores the kvcache in an
            # interleaved style: the kvcache for the token whose token_idx
            # is i is always stored on the GPU whose dcp_rank equals
            # i % pcp_world_size.

            # Use a "virtual block" which equals world_size * block_size
            # for block_table_indices calculation.
            virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size

            # IMPORTANT: In hybrid mode, positions are in logical block space,
            # but we need to map them to the correct logical block table indices.
            logical_block_idx = positions // virtual_block_size

            # Account for the expanded logical table (always needed with the
            # unified tensor): each physical block is split into multiple
            # logical blocks, and the logical table has been expanded to
            # accommodate this.
            block_table_indices = (
                req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
            )

            block_numbers = self.block_table.np.ravel()[block_table_indices]
            # Use virtual_block_size for mask calculation, which marks local
            # tokens.
            virtual_block_offsets = positions % virtual_block_size
            self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
            mask = (
                virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size)
                == self.current_rank
            )
            # Calculate local block_offsets
            block_offsets = (
                virtual_block_offsets
                // (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size)
                * self.cp_kv_cache_interleave_size
                + virtual_block_offsets % self.cp_kv_cache_interleave_size
            )
            # Calculate slot_mapping
            slot_mapping = block_numbers * self.block_size + block_offsets
            # Write final slots, use -1 for non-local tokens
            self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1)
        else:
            assert self.kernel_sizes is not None
            if self.block_size == self.kernel_sizes[0]:
                # IMPORTANT: In hybrid mode, positions are in logical block space,
                # but we need to map them to the correct logical block table indices.
                logical_block_idx = positions // self.block_size

                # Account for the expanded logical table (always needed with the
                # unified tensor): each physical block is split into multiple
                # logical blocks, and the logical table has been expanded to
                # accommodate this.
                block_table_indices = (
                    req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
                )

                block_numbers = self.block_table.np.ravel()[block_table_indices]
                block_offsets = positions % self.block_size
                np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]])

    def commit_block_table(self, num_reqs: int) -> None:
        self.block_table.copy_to_gpu(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        self.slot_mapping.copy_to_gpu(num_tokens)

    def clear(self) -> None:
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)

    def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray:
        """Convert physical block IDs to logical block IDs."""
        if not self.use_hybrid_blocks:
            return physical_blocks

        # Create logical block IDs by splitting each physical block
        logical_blocks: list[int] = []
        for phys_block in physical_blocks:
            # Convert each physical block to multiple logical blocks.
            # Physical block 1 becomes logical blocks
            # [1 * split_ratio, 1 * split_ratio + 1, ...].
            # But we need to account for the fact that block 0 is special.
            base_logical = phys_block * self.blocks_per_phys_block
            logical_blocks.extend(range(base_logical, base_logical + self.blocks_per_phys_block))

        return np.array(logical_blocks, dtype=np.int32)

    def get_device_tensor(self) -> torch.Tensor:
        """Returns the device tensor of the block table."""
        return self.block_table.gpu

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table.cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table.np

    def _make_buffer(self, *size: int | torch.SymInt, dtype: torch.dtype) -> CpuGpuBuffer:
        return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)

class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
        num_speculative_tokens: int = 0,
        max_num_blocks: list[int] | None = None,
        kernel_sizes: list[list[int]] | None = None,
        cp_kv_cache_interleave_size: int = 1,
    ) -> None:
        if kernel_sizes is None:
            kernel_sizes = [[0]] * len(block_sizes)
        # Ensure kernel_sizes matches block_sizes length
        elif len(kernel_sizes) == 1 and len(block_sizes) > 1:
            kernel_sizes = kernel_sizes * len(block_sizes)
        elif len(kernel_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
            )

        if max_num_blocks is None:
            # Note(hc): each dcp rank only stores
            # (max_model_len // dcp_world_size) tokens in kvcache,
            # so the block_size used to calculate max_num_blocks_per_req
            # must be multiplied by dcp_world_size.
            total_cp_world_size = get_total_cp_world_size()
            max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes]

        if len(max_num_blocks) != len(block_sizes):
            raise ValueError(
                f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})"
            )

        # Use zip to pair block_sizes with kernel_sizes one-to-one
        self.block_tables = [
            BlockTable(
                block_size,
                max_num_reqs,
                max_num_blocks_per_req,
                max_num_batched_tokens,
                pin_memory,
                device,
                kernel_size_list,
                cp_kv_cache_interleave_size,
                num_speculative_tokens,
            )
            for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks)
        ]

    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(req_indices, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def commit_slot_mapping(self, num_tokens: int) -> None:
        for block_table in self.block_tables:
            block_table.commit_slot_mapping(num_tokens)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
```