[Hybrid] support prefix cache for Qwen3.5/Next with --mamba-cache-mode align (#7103)

### What this PR does / why we need it?
To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](https://github.com/vllm-project/vllm/pull/30877) and inherits
changes to functions which are overridden in vLLM-Ascend.

Note:
1. `--mamba-cache-mode align` && PD disaggregation is still not
supported yet in vLLM v0.17.0(see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of hybrid kv cache might result in a very
large block_size when scheduling. For example, if we run Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 will never be cached. Although this behavior is
consistent with vLLM, it still needs improvements in the future.
3. `--mamba-cache-mode align` requires to copy mamba states during
forward steps. vLLM uses a triton kernel to implement it. However, the
original version run into some bugs on Ascend hardwares. Thus we patch a
new triton kernel to avoid this bug.

### Does this PR introduce _any_ user-facing change?
To use mamba prefix cache, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the mamba state copy function(see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
does not provide a torch native version, thus it might have trouble if
users can't use triton.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2026-03-15 09:44:09 +08:00
committed by GitHub
parent c69291eefc
commit ce5544bfc1
8 changed files with 173 additions and 17 deletions

View File

@@ -77,6 +77,10 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.cp_utils import (
get_total_cp_world_size,
)
from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
@@ -416,6 +420,8 @@ class NPUModelRunner(GPUModelRunner):
self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
else:
self.cudagraph_batch_sizes = []
self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
@property
def use_cp(self) -> bool:
@@ -1250,6 +1256,23 @@ class NPUModelRunner(GPUModelRunner):
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
# NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877,
# there should be a corresponding 'postprocess_mamba'. However, it is called inside
# '_update_states_after_model_execute', which is not overridden in vLLM-Ascend.
# We simply utilize the implementation in vLLM.
if self.cache_config.mamba_cache_mode == "align":
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
self.cache_config,
self.mamba_state_idx,
self.input_batch,
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -2542,6 +2565,7 @@ class NPUModelRunner(GPUModelRunner):
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
@@ -2983,6 +3007,21 @@ class NPUModelRunner(GPUModelRunner):
# of mamba block. In this case, BlockTable.block_size will never equal
# to kernel_block_sizes[0]
self.kernel_block_sizes.append([0])
max_num_blocks = []
max_model_len = max(self.max_model_len, self.max_encoder_len)
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size())
if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
mamba_blocks_per_req = (
max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req)
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
@@ -2991,7 +3030,7 @@ class NPUModelRunner(GPUModelRunner):
)
self.input_batch = NPUInputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max(self.model_config.max_model_len, self.max_encoder_len),
max_model_len=max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
@@ -3006,6 +3045,7 @@ class NPUModelRunner(GPUModelRunner):
else 0
),
kernel_block_sizes=self.kernel_block_sizes,
max_num_blocks_per_req=max_num_blocks,
)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
@@ -3171,8 +3211,6 @@ class NPUModelRunner(GPUModelRunner):
mamba_layers[layer_name] = attn_module
if len(mamba_layers) > 0:
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
mamba_page_size_padded = 0
for layer_name, mamba_module in mamba_layers.items():
if spec := mamba_module.get_kv_cache_spec(self.vllm_config):