From ce5544bfc1df73ecbe0cb9b7b5fa76bc26b5b5d8 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Sun, 15 Mar 2026 09:44:09 +0800 Subject: [PATCH] [Hybrid] support prefix cache for Qwen3.5/Next with `--mamba-cache-mode align` (#7103) ### What this PR does / why we need it? To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly follows the design in [#30877](https://github.com/vllm-project/vllm/pull/30877) and inherits changes to functions which are overridden in vLLM-Ascend. Note: 1. `--mamba-cache-mode align` && PD disaggregation is still not supported yet in vLLM v0.17.0(see https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295). 2. The current implementation of hybrid kv cache might result in a very large block_size when scheduling. For example, if we run Qwen3.5-35B-A3B with `-tp 2`, the block_size is adjusted to 2048, which means that any prefix shorter than 2048 will never be cached. Although this behavior is consistent with vLLM, it still needs improvements in the future. 3. `--mamba-cache-mode align` requires to copy mamba states during forward steps. vLLM uses a triton kernel to implement it. However, the original version run into some bugs on Ascend hardwares. Thus we patch a new triton kernel to avoid this bug. ### Does this PR introduce _any_ user-facing change? To use mamba prefix cache, set `--enable-prefix-caching` and `--mamba-cache-mode align`. Note that the mamba state copy function(see [do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132)) does not provide a torch native version, thus it might have trouble if users can't use triton. - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: Angazenn --- .../triton/test_batch_memcpy.py | 38 ++++++++++++++++ vllm_ascend/ops/triton/batch_memcpy.py | 31 +++++++++++++ vllm_ascend/patch/__init__.py | 22 ++++++++++ vllm_ascend/patch/worker/__init__.py | 1 + vllm_ascend/patch/worker/patch_mamba_utils.py | 21 +++++++++ vllm_ascend/worker/block_table.py | 31 +++++++------ vllm_ascend/worker/model_runner_v1.py | 44 +++++++++++++++++-- vllm_ascend/worker/npu_input_batch.py | 2 + 8 files changed, 173 insertions(+), 17 deletions(-) create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py create mode 100644 vllm_ascend/ops/triton/batch_memcpy.py create mode 100644 vllm_ascend/patch/worker/patch_mamba_utils.py diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py new file mode 100644 index 00000000..b5162a6d --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_batch_memcpy.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_batch_memcpy(dtype): + element_size = 2 if dtype == torch.bfloat16 else 4 + device = "npu:0" + # this is a typical case when used in mamba states copy. + sizes = torch.tensor([24576, 262144, 24576, 262144], device=device, dtype=torch.int32) + + src_tensors_list = [] + src_addr_list = [] + dst_tensors_list = [] + dst_addr_list = [] + for i in range(len(sizes)): + src_tensors_list.append( + torch.rand(sizes[i].item() // element_size, dtype=dtype, device=device) + ) + src_addr_list.append(src_tensors_list[-1].data_ptr()) + dst_tensors_list.append( + torch.empty(sizes[i].item() // element_size, dtype=dtype, device=device) + ) + dst_addr_list.append(dst_tensors_list[-1].data_ptr()) + + src_addr_list = torch.tensor(src_addr_list, dtype=torch.int64, device=device) + dst_addr_list = torch.tensor(dst_addr_list, dtype=torch.int64, device=device) + + batch = sizes.shape[0] + + grid = (batch,) + # using larger block_size to accelerate copy. + BLOCK_SIZE = 8192 + batch_memcpy_kernel[grid](src_addr_list, dst_addr_list, sizes, BLOCK_SIZE=BLOCK_SIZE) + + for i in range(len(sizes)): + torch.testing.assert_close(src_tensors_list[i], dst_tensors_list[i], rtol=0, atol=0) diff --git a/vllm_ascend/ops/triton/batch_memcpy.py b/vllm_ascend/ops/triton/batch_memcpy.py new file mode 100644 index 00000000..0bd576d3 --- /dev/null +++ b/vllm_ascend/ops/triton/batch_memcpy.py @@ -0,0 +1,31 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from vllm.triton_utils import tl, triton + + +@triton.jit +def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + + src_ptr = tl.load(src_ptrs + pid) + dst_ptr = tl.load(dst_ptrs + pid) + size = tl.load(sizes + pid) + + # We need to mv pointer_type cast outside the loop. + # Otherwise it causes potential bugs. + src_ptr = src_ptr.to(tl.pointer_type(tl.uint8)) + dst_ptr = dst_ptr.to(tl.pointer_type(tl.uint8)) + + offsets = tl.arange(0, BLOCK_SIZE) + for i in range(0, size, BLOCK_SIZE): + mask = (i + offsets) < size + + curr_src_ptr = src_ptr + i + offsets + curr_dst_ptr = dst_ptr + i + offsets + + # cache_modifier=".cg" bypasses L1 cache for streaming data. + data = tl.load(curr_src_ptr, mask=mask, cache_modifier=".cg") + tl.store(curr_dst_ptr, data, mask=mask, cache_modifier=".cg") diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 4eb74013..41463c7b 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -506,3 +506,25 @@ # Rotary quant is a unique feature of vllm-ascend. # Future Plan: # Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`. +# ** 22. File: worker/patch_mamba_utils.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.v1.worker.mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel` +# Why: +# Oringnal batch_memcpy_kernel implemented in vLLM might encounter bugs when running on +# Ascend hardwares. +# How: +# patch to fix related bugs. +# Future Plan: +# Remove this patch when: +# (1) oringnal batch_memcpy_kernel can run on Ascend hardware. +# or +# (2) design a dispatch mechanism for batch_memcpy_kernel. +# 2. `vllm.v1.worker.mamba_utils.batch_memcpy = batch_memcpy` +# Why: +# vLLM use BLOCK_SIZE 1024 for batch_memcpy_kernel. This results in suboptimal performance +# on Ascend hardwares. +# How: +# patch to change BLOCK_SIZE to 8192. +# Future Plan: +# Remove this patch when: +# design a dispatch mechanism for batch_memcpy_kernel. diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index f7d509a2..06aa5d2a 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -28,6 +28,7 @@ import vllm_ascend.patch.worker.patch_bert # noqa import vllm_ascend.patch.worker.patch_distributed # noqa import vllm_ascend.patch.worker.patch_minimax_m2 # noqa import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa +import vllm_ascend.patch.worker.patch_mamba_utils # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa diff --git a/vllm_ascend/patch/worker/patch_mamba_utils.py b/vllm_ascend/patch/worker/patch_mamba_utils.py new file mode 100644 index 00000000..063789bf --- /dev/null +++ b/vllm_ascend/patch/worker/patch_mamba_utils.py @@ -0,0 +1,21 @@ +# mypy: ignore-errors + + +from vllm.v1.worker import mamba_utils + +from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel + + +def batch_memcpy(src_ptrs, dst_ptrs, sizes): + batch = src_ptrs.shape[0] + assert dst_ptrs.shape[0] == batch + assert sizes.shape[0] == batch + + grid = (batch,) + # using larger block_size to accelerate copy. + BLOCK_SIZE = 8192 + batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE) + + +mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel +mamba_utils.batch_memcpy = batch_memcpy diff --git a/vllm_ascend/worker/block_table.py b/vllm_ascend/worker/block_table.py index 4ffc7df6..3c812aa4 100644 --- a/vllm_ascend/worker/block_table.py +++ b/vllm_ascend/worker/block_table.py @@ -3,6 +3,7 @@ import torch from vllm.distributed import get_dcp_group, get_pcp_group from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.cp_utils import get_total_cp_world_size class BlockTable: @@ -239,21 +240,10 @@ class MultiGroupBlockTable: device: torch.device, block_sizes: list[int], num_speculative_tokens: int = 0, + max_num_blocks: list[int] | None = None, kernel_sizes: list[list[int]] | None = None, cp_kv_cache_interleave_size: int = 1, ) -> None: - # Note(hc): each dcp rank only store - # (max_model_len//dcp_world_size) tokens in kvcache, - # so the block_size which used for calc max_num_blocks_per_req - # must be multiplied by dcp_world_size. - try: - dcp_world_size = get_dcp_group().world_size - pcp_world_size = get_pcp_group().world_size - except AssertionError: - # DCP might not be initialized in testing - dcp_world_size = 1 - pcp_world_size = 1 - if kernel_sizes is None: kernel_sizes = [[0]] * len(block_sizes) # Ensure kernel_sizes matches block_sizes length @@ -264,12 +254,25 @@ class MultiGroupBlockTable: f"kernel_sizes length ({len(kernel_sizes)}) must match block_sizes length ({len(block_sizes)})" ) + if max_num_blocks is None: + # Note(hc): each dcp rank only store + # (max_model_len//dcp_world_size) tokens in kvcache, + # so the block_size which used for calc max_num_blocks_per_req + # must be multiplied by dcp_world_size. + total_cp_world_size = get_total_cp_world_size() + max_num_blocks = [cdiv(max_model_len, block_size * total_cp_world_size) for block_size in block_sizes] + + if len(max_num_blocks) != len(block_sizes): + raise ValueError( + f"max_num_blocks length ({len(max_num_blocks)}) must match block_sizes length ({len(block_sizes)})" + ) + # Use zip to pair block_sizes with kernel_sizes one-to-one self.block_tables = [ BlockTable( block_size, max_num_reqs, - max(cdiv(max_model_len, block_size * dcp_world_size * pcp_world_size), 1 + num_speculative_tokens), + max_num_blocks_per_req, max_num_batched_tokens, pin_memory, device, @@ -277,7 +280,7 @@ class MultiGroupBlockTable: cp_kv_cache_interleave_size, num_speculative_tokens, ) - for block_size, kernel_size_list in zip(block_sizes, kernel_sizes) + for block_size, kernel_size_list, max_num_blocks_per_req in zip(block_sizes, kernel_sizes, max_num_blocks) ] def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 896e90e0..3c6d48bc 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -77,6 +77,10 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import record_function_or_nullcontext +from vllm.v1.worker import mamba_utils +from vllm.v1.worker.cp_utils import ( + get_total_cp_world_size, +) from vllm.v1.worker.gpu_model_runner import AsyncGPUModelRunnerOutput, GPUModelRunner from vllm.v1.worker.ubatch_utils import ( UBatchSlices, @@ -416,6 +420,8 @@ class NPUModelRunner(GPUModelRunner): self.cudagraph_batch_sizes = sorted(self.compilation_config.cudagraph_capture_sizes) else: self.cudagraph_batch_sizes = [] + self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None @property def use_cp(self) -> bool: @@ -1250,6 +1256,23 @@ class NPUModelRunner(GPUModelRunner): pad_attn = cudagraph_mode == CUDAGraphMode.FULL + # NOTE(Angazenn): According to https://github.com/vllm-project/vllm/pull/30877, + # there should be a corresponding 'postprocess_mamba'. However, it is called inside + # '_update_states_after_model_execute', which is not overridden in vLLM-Ascend. + # We simply utilize the implementation in vLLM. + if self.cache_config.mamba_cache_mode == "align": + mamba_utils.preprocess_mamba( + scheduler_output, + self.kv_cache_config, + self.cache_config, + self.mamba_state_idx, + self.input_batch, + self.requests, + self.compilation_config.static_forward_context, + self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), + ) + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -2542,6 +2565,7 @@ class NPUModelRunner(GPUModelRunner): """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) # NOTE(cmq): initialize_attn_backend must before using self.attn_groups @@ -2983,6 +3007,21 @@ class NPUModelRunner(GPUModelRunner): # of mamba block. In this case, BlockTable.block_size will never equal # to kernel_block_sizes[0] self.kernel_block_sizes.append([0]) + + max_num_blocks = [] + max_model_len = max(self.max_model_len, self.max_encoder_len) + for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): + if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec): + continue + max_num_blocks_per_req = cdiv(max_model_len, block_sizes[i] * get_total_cp_world_size()) + if isinstance(kv_cache_group.kv_cache_spec, MambaSpec): + mamba_blocks_per_req = ( + max_num_blocks_per_req if self.cache_config.enable_prefix_caching else 1 + ) + kv_cache_group.kv_cache_spec.num_speculative_blocks + + max_num_blocks_per_req = max(max_num_blocks_per_req, mamba_blocks_per_req) + max_num_blocks.append(max_num_blocks_per_req) + if block_sizes != [self.cache_config.block_size] or self.kernel_block_sizes != [[self.cache_config.block_size]]: assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " @@ -2991,7 +3030,7 @@ class NPUModelRunner(GPUModelRunner): ) self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, - max_model_len=max(self.model_config.max_model_len, self.max_encoder_len), + max_model_len=max_model_len, max_num_batched_tokens=self.max_num_tokens, device=self.device, pin_memory=self.pin_memory, @@ -3006,6 +3045,7 @@ class NPUModelRunner(GPUModelRunner): else 0 ), kernel_block_sizes=self.kernel_block_sizes, + max_num_blocks_per_req=max_num_blocks, ) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: @@ -3171,8 +3211,6 @@ class NPUModelRunner(GPUModelRunner): mamba_layers[layer_name] = attn_module if len(mamba_layers) > 0: - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError("Prefix caching is not supported for Mamba yet.") mamba_page_size_padded = 0 for layer_name, mamba_module in mamba_layers.items(): if spec := mamba_module.get_kv_cache_spec(self.vllm_config): diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 2d7a7c8b..b5f21a6a 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -40,6 +40,7 @@ class NPUInputBatch(InputBatch): vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group kernel_block_sizes: list[list[int]], + max_num_blocks_per_req: list[int] | None = None, logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, @@ -97,6 +98,7 @@ class NPUInputBatch(InputBatch): pin_memory=pin_memory, device=device, block_sizes=block_sizes, + max_num_blocks=max_num_blocks_per_req, num_speculative_tokens=num_speculative_tokens, kernel_sizes=kernel_block_sizes, cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,