Files
xc-llm-ascend/vllm_ascend/worker/block_table.py
Angazenn ce5544bfc1 [Hybrid] support prefix cache for Qwen3.5/Next with --mamba-cache-mode align (#7103)
### What this PR does / why we need it?
To support prefix caching for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](https://github.com/vllm-project/vllm/pull/30877) and ports the
corresponding changes into the functions that vLLM-Ascend overrides.

Note:
1. Combining `--mamba-cache-mode align` with PD disaggregation is not yet
supported in vLLM v0.17.0 (see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of the hybrid KV cache may produce a very
large block_size during scheduling. For example, if we run Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 tokens will never be cached. Although this behavior is
consistent with vLLM, it still needs improvement in the future.
3. `--mamba-cache-mode align` requires copying mamba states during
forward steps. vLLM implements this with a Triton kernel; however, the
original kernel runs into bugs on Ascend hardware, so this PR patches in a
new Triton kernel that avoids them.

### Does this PR introduce _any_ user-facing change?
To use mamba prefix caching, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the mamba state copy function (see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
does not provide a torch-native fallback, so it may not work for
users who cannot use Triton.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Angazenn <supperccell@163.com>
2026-03-15 09:44:09 +08:00

321 lines
14 KiB
Python

import numpy as np
import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
class BlockTable:
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_sizes: list[int] | None = None,
cp_kv_cache_interleave_size: int = 1,
num_speculative_tokens: int = 0,
):
self.max_num_reqs = max_num_reqs
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
self.physical_block_size = block_size
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_world_size > 1 else 0
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.pcp_world_size = 1
self.pcp_rank = 0
# If kernel_sizes is None or [0], use physical block size (no splitting)
if kernel_sizes is None or kernel_sizes == [0]:
self.block_size = block_size
self.logical_block_size = block_size
self.blocks_per_phys_block = 1
self.use_hybrid_blocks = False
else:
# Find the first kernel size that divides physical_block_size evenly
selected_kernel_size = None
for kernel_size in kernel_sizes:
if kernel_size > 0 and self.physical_block_size % kernel_size == 0:
selected_kernel_size = kernel_size
break
if selected_kernel_size is None:
raise ValueError(
f"None of the kernel sizes {kernel_sizes} can divide "
f"physical block size {self.physical_block_size} evenly"
)
self.block_size = selected_kernel_size
self.logical_block_size = selected_kernel_size
self.blocks_per_phys_block = self.physical_block_size // self.logical_block_size
if self.blocks_per_phys_block > 1:
self.use_hybrid_blocks = True
else:
self.use_hybrid_blocks = False
if self.use_hybrid_blocks:
logical_table_size = max_num_blocks_per_req * self.blocks_per_phys_block
else:
logical_table_size = max_num_blocks_per_req
duplicate_size = 1
if self.pcp_world_size * self.dcp_world_size > 1:
duplicate_size += num_speculative_tokens
self.block_table = self._make_buffer(max_num_reqs * duplicate_size, logical_table_size, dtype=torch.int32)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens + 2 * self.pcp_world_size * self.max_num_reqs, dtype=torch.int32
)
self.kernel_sizes = kernel_sizes
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
block_ids,
row_idx: int,
) -> None:
if not block_ids:
return
block_ids = np.array(block_ids)
if self.use_hybrid_blocks:
block_ids = self._convert_physical_to_logical_blocks(block_ids)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
self.block_table.np[row_idx, start : start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] += num_blocks
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table.np[tgt, :num_blocks] = self.block_table.np[src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
self.num_blocks_per_row[src] = num_blocks_tgt
self.num_blocks_per_row[tgt] = num_blocks_src
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size * self.pcp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % pcp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size * self.pcp_world_size
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
logical_block_idx = positions // virtual_block_size
# Account for the expanded logical table
# (always needed with unified tensor)
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (
req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
self.current_rank = self.dcp_world_size * self.pcp_rank + self.dcp_rank
mask = (
virtual_block_offsets // self.cp_kv_cache_interleave_size % (self.dcp_world_size * self.pcp_world_size)
== self.current_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.pcp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(mask, slot_mapping, -1)
else:
assert self.kernel_sizes is not None
if self.block_size == self.kernel_sizes[0]:
# IMPORTANT: In hybrid mode, positions are in logical block space,
# but we need to map them to the correct logical block table indices
logical_block_idx = positions // self.block_size
# Account for the expanded logical table
# (always needed with unified tensor)
# Each physical block is split into multiple logical blocks
# The logical table has been expanded to accommodate this
block_table_indices = (
req_indices * self.max_num_blocks_per_req * self.blocks_per_phys_block + logical_block_idx
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(block_numbers * self.block_size, block_offsets, out=self.slot_mapping.np[: req_indices.shape[0]])
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table.cpu.fill_(0)
def _convert_physical_to_logical_blocks(self, physical_blocks: np.ndarray) -> np.ndarray:
"""Convert physical block IDs to logical block IDs."""
if not self.use_hybrid_blocks:
return physical_blocks
# Create logical block IDs by splitting each physical block
logical_blocks: list[int] = []
for phys_block in physical_blocks:
# Convert physical block to multiple logical blocks
# Physical block 1 becomes logical blocks
# [1*split_ratio, 1*split_ratio+1, ...]
# But we need to account for the fact that block 0 is special
base_logical = phys_block * self.blocks_per_phys_block
logical_blocks.extend(range(base_logical, base_logical + self.blocks_per_phys_block))
return np.array(logical_blocks, dtype=np.int32)
def get_device_tensor(self) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table.cpu
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table.np
def _make_buffer(self, *size: int | torch.SymInt, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
max_num_blocks: list[int] | None = None,
kernel_sizes: list[list[int]] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
# Ensure kernel_sizes matches block_sizes length
elif len(kernel_sizes) == 1 and len(block_sizes) > 1:
kernel_sizes = kernel_sizes * len(block_sizes)
elif len(kernel_sizes) != len(block_sizes):
raise ValueError(
f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
)
if max_num_blocks is None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes]
if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})"
)
# Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
kernel_size_list,
cp_kv_cache_interleave_size,
num_speculative_tokens,
)
for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def compute_slot_mapping(self, req_indices: np.ndarray, positions: np.ndarray) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]