xc-llm-ascend/vllm_ascend/patch/worker/patch_mamba_utils.py

[Hybrid] support prefix cache for Qwen3.5/Next with `--mamba-cache-mode align` (#7103)

### What this PR does / why we need it?
To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly follows the design in [#30877](https://github.com/vllm-project/vllm/pull/30877) and inherits the changes to functions that vLLM-Ascend overrides.

Note:
1. `--mamba-cache-mode align` combined with PD disaggregation is still not supported in vLLM v0.17.0 (see https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of the hybrid KV cache can produce a very large block_size during scheduling. For example, running Qwen3.5-35B-A3B with `-tp 2` adjusts the block_size to 2048, which means that any prefix shorter than 2048 tokens will never be cached. Although this behavior is consistent with vLLM, it still needs improvement in the future.
3. `--mamba-cache-mode align` requires copying mamba states during forward steps. vLLM implements this with a Triton kernel, but the original version hits bugs on Ascend hardware, so this PR patches in a new Triton kernel that avoids them.

### Does this PR introduce _any_ user-facing change?
To use mamba prefix cache, set `--enable-prefix-caching` and `--mamba-cache-mode align`. Note that the mamba state copy function (see [do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132)) does not provide a torch-native version, so it may not work for users who cannot use Triton.

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: Angazenn <supperccell@163.com>
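The effect described in note 2 comes down to simple arithmetic. The sketch below assumes that the prefix cache only reuses complete blocks, so a prefix contributes `floor(prefix_len / block_size)` full blocks; `cacheable_tokens` is a hypothetical helper for illustration, not a vLLM or vLLM-Ascend API.

```python
# Hypothetical helper (not part of vLLM): assumes only complete blocks
# of block_size tokens can be stored in and reused from the prefix cache.
def cacheable_tokens(prefix_len: int, block_size: int) -> int:
    """Return how many leading tokens of a prefix fall into full blocks."""
    if block_size <= 0:
        raise ValueError("block_size must be positive")
    full_blocks = prefix_len // block_size  # partial trailing block is dropped
    return full_blocks * block_size


if __name__ == "__main__":
    # block_size = 2048 mirrors the Qwen3.5-35B-A3B `-tp 2` example above.
    for prefix_len in (2000, 2048, 5000):
        print(prefix_len, cacheable_tokens(prefix_len, 2048))
    # prints:
    # 2000 0
    # 2048 2048
    # 5000 4096
```

Under this assumption, a 2000-token prefix gets no cache hit at all, and even a 5000-token prefix reuses only its first 4096 tokens, which is why a large adjusted block_size hurts short prompts the most.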
2026-03-15 09:44:09 +08:00
# mypy: ignore-errors
from vllm.v1.worker import mamba_utils

from vllm_ascend.ops.triton.batch_memcpy import batch_memcpy_kernel


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    # Launch one Triton program per copy: program i copies a region of
    # length sizes[i] from src_ptrs[i] to dst_ptrs[i].
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch
    assert sizes.shape[0] == batch
    grid = (batch,)
    # Use a larger BLOCK_SIZE to speed up the copy.
    BLOCK_SIZE = 8192
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)


# Replace vLLM's kernel and wrapper so that mamba state copies use the
# Ascend-compatible Triton kernel instead of the original one.
mamba_utils.batch_memcpy_kernel = batch_memcpy_kernel
mamba_utils.batch_memcpy = batch_memcpy
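To make the contract of `batch_memcpy` concrete, here is a pure-Python reference model. It is a sketch under stated assumptions, not the Triton kernel itself: integer offsets into a `bytearray` stand in for device pointers, each batch entry plays the role of one program in `grid = (batch,)`, and the copy is assumed to proceed in `block_size` chunks (default 8192, mirroring `BLOCK_SIZE` above). `batch_memcpy_reference` is a hypothetical name and is not a drop-in replacement for the patched function.

```python
from typing import Sequence


def batch_memcpy_reference(memory: bytearray,
                           src_ptrs: Sequence[int],
                           dst_ptrs: Sequence[int],
                           sizes: Sequence[int],
                           block_size: int = 8192) -> None:
    """Hypothetical model: copy sizes[i] bytes from src_ptrs[i] to dst_ptrs[i].

    Offsets into `memory` stand in for device pointers (an assumption made
    for illustration only).
    """
    batch = len(src_ptrs)
    assert len(dst_ptrs) == batch
    assert len(sizes) == batch
    for i in range(batch):  # one "program" per copy, like grid = (batch,)
        src, dst, n = src_ptrs[i], dst_ptrs[i], sizes[i]
        # Walk the region in block_size chunks; the last chunk may be
        # partial (a real kernel would typically mask the tail).
        for off in range(0, n, block_size):
            chunk = min(block_size, n - off)
            memory[dst + off:dst + off + chunk] = memory[src + off:src + off + chunk]


if __name__ == "__main__":
    mem = bytearray(b"abcdefgh" + bytes(8))
    # Two copies of 4 bytes each, using a tiny block_size to force chunking.
    batch_memcpy_reference(mem, [0, 4], [8, 12], [4, 4], block_size=3)
    print(bytes(mem[8:16]))  # prints b'abcdefgh'
```

A larger chunk size means fewer loop iterations per program, which matches the intent of the `BLOCK_SIZE = 8192` comment. Note also that the patch works by rebinding attributes on the `mamba_utils` module: in Python, this only affects code that looks up `mamba_utils.batch_memcpy` at call time, while a name bound earlier via `from ... import batch_memcpy` would keep pointing at the original function.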