[Hybrid] support prefix cache for Qwen3.5/Next with --mamba-cache-mode align (#7103)

### What this PR does / why we need it?
To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](https://github.com/vllm-project/vllm/pull/30877) and inherits
changes to functions which are overridden in vLLM-Ascend.

Note:
1. `--mamba-cache-mode align` && PD disaggregation is still not
supported yet in vLLM v0.17.0(see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of hybrid kv cache might result in a very
large block_size when scheduling. For example, if we run Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 will never be cached. Although this behavior is
consistent with vLLM, it still needs improvements in the future.
3. `--mamba-cache-mode align` requires to copy mamba states during
forward steps. vLLM uses a triton kernel to implement it. However, the
original version run into some bugs on Ascend hardwares. Thus we patch a
new triton kernel to avoid this bug.

### Does this PR introduce _any_ user-facing change?
To use mamba prefix cache, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the mamba state copy function(see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
does not provide a torch native version, thus it might have trouble if
users can't use triton.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-03-15 09:44:09 +08:00
committed by GitHub
parent c69291eefc
commit ce5544bfc1
8 changed files with 173 additions and 17 deletions

View File

@@ -3,6 +3,7 @@ import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
class BlockTable:
@@ -239,21 +240,10 @@ class MultiGroupBlockTable:
device: torch.device,
block_sizes: list[int],
num_speculative_tokens: int = 0,
max_num_blocks: list[int] | None = None,
kernel_sizes: list[list[int]] | None = None,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
try:
dcp_world_size = get_dcp_group().world_size
pcp_world_size = get_pcp_group().world_size
except AssertionError:
# DCP might not be initialized in testing
dcp_world_size = 1
pcp_world_size = 1
if kernel_sizes is None:
kernel_sizes = [[0]] * len(block_sizes)
# Ensure kernel_sizes matches block_sizes length
@@ -264,12 +254,25 @@ class MultiGroupBlockTable:
f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})"
)
if max_num_blocks is None:
# Note(hc): each dcp rank only store
# (max_model_len//dcp_world_size) tokens in kvcache,
# so the block_size which used for calc max_num_blocks_per_req
# must be multiplied by dcp_world_size.
total_cp_world_size = get_total_cp_world_size()
max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes]
if len(max_num_blocks) != len(block_sizes):
raise ValueError(
f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})"
)
# Use zip to pair block_sizes with kernel_sizes one-to-one
self.block_tables = [
BlockTable(
block_size,
max_num_reqs,
max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens),
max_num_blocks_per_req,
max_num_batched_tokens,
pin_memory,
device,
@@ -277,7 +280,7 @@ class MultiGroupBlockTable:
cp_kv_cache_interleave_size,
num_speculative_tokens,
)
for block_size, kernel_size_list in zip(block_sizes, kernel_sizes)
for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks)
]
def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: