Files
xc-llm-ascend/vllm_ascend/worker/npu_input_batch.py
Angazenn ce5544bfc1 [Hybrid] support prefix cache for Qwen3.5/Next with --mamba-cache-mode align (#7103)
### What this PR does / why we need it?
To support prefix cache for Qwen3.5/Next in vLLM-Ascend, this PR mainly
follows the design in
[#30877](https://github.com/vllm-project/vllm/pull/30877) and inherits
changes to functions which are overridden in vLLM-Ascend.

Note:
1. Combining `--mamba-cache-mode align` with PD disaggregation is not yet
supported in vLLM v0.17.0 (see
https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/sched/scheduler.py#L295).
2. The current implementation of hybrid kv cache might result in a very
large block_size when scheduling. For example, if we run Qwen3.5-35B-A3B
with `-tp 2`, the block_size is adjusted to 2048, which means that any
prefix shorter than 2048 will never be cached. Although this behavior is
consistent with vLLM, it still needs improvements in the future.
3. `--mamba-cache-mode align` requires copying mamba states during
forward steps. vLLM implements this with a Triton kernel. However, the
original kernel runs into bugs on Ascend hardware, so we patch in a
new Triton kernel to avoid them.

### Does this PR introduce _any_ user-facing change?
To use mamba prefix cache, set `--enable-prefix-caching` and
`--mamba-cache-mode align`. Note that the mamba state copy function (see
[do_mamba_copy_block](https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/mamba_utils.py#L132))
has no torch-native implementation, so it may not work for users who
cannot use Triton.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: Angazenn <supperccell@163.com>
2026-03-15 09:44:09 +08:00

216 lines
9.5 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/v1/worker/gpu_input_batch.py
#
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingStates
from vllm.v1.sample.logits_processor import BatchUpdateBuilder, LogitsProcessors
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_ascend.worker.block_table import MultiGroupBlockTable
class NPUInputBatch(InputBatch):
    """``InputBatch`` subclass whose block table is a ``MultiGroupBlockTable``.

    ``__init__`` does not call ``super().__init__``; it assigns every
    attribute directly and builds ``self.block_table`` from
    ``vllm_ascend.worker.block_table.MultiGroupBlockTable``.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[list[int]],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        """Allocate the per-request buffers and bookkeeping containers.

        Args:
            max_num_reqs: Length of every per-request buffer.
            max_model_len: Width of the per-request token buffers.
            max_num_batched_tokens: Stored and forwarded to the block table.
            device: Device holding the sampling-parameter tensors; also
                forwarded to the block table.
            pin_memory: Whether the CPU mirrors of the per-request tensors
                are allocated in pinned memory.
            vocab_size: Stored on the instance.
            block_sizes: Block size of each kv cache group.
            kernel_block_sizes: Forwarded to the block table as
                ``kernel_sizes``.
            max_num_blocks_per_req: Forwarded to the block table as
                ``max_num_blocks``.
            logitsprocs: Logits processors to use; an empty
                ``LogitsProcessors()`` is created when this is falsy.
            logitsprocs_need_output_token_ids: Stored on the instance.
            is_spec_decode: Stored on the instance.
            is_pooling_model: Stored on the instance.
            num_speculative_tokens: Forwarded to the block table.
            cp_kv_cache_interleave_size: Forwarded to the block table.
        """
        # Shapes shared by the allocations below.
        per_req = (max_num_reqs,)
        token_grid = (max_num_reqs, max_model_len)

        def on_device(dtype: torch.dtype) -> torch.Tensor:
            # Uninitialised per-request tensor living on `device`.
            return torch.empty(per_req, dtype=dtype, device=device)

        def on_host(dtype: torch.dtype) -> torch.Tensor:
            # Uninitialised per-request CPU tensor, pinned iff `pin_memory`.
            return torch.empty(per_req, dtype=dtype, device="cpu", pin_memory=pin_memory)

        # ---- Plain configuration -------------------------------------
        self.is_pooling_model = is_pooling_model
        self.is_spec_decode = is_spec_decode
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size

        # ---- Request bookkeeping -------------------------------------
        self._req_ids: list[str | None] = []
        self.req_id_to_index: dict[str, int] = {}

        # TODO(woosuk): This buffer could be too large if max_model_len is
        # big. Find a way to reduce the CPU memory usage.
        # The buffer is never transferred to the device directly, hence it
        # is deliberately left unpinned.
        self.token_ids_cpu_tensor = torch.zeros(
            token_grid,
            dtype=torch.int32,
            device="cpu",
            pin_memory=False,
        )
        # NumPy view of the tensor above (``Tensor.numpy()`` shares storage).
        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
        self.is_token_ids_tensor = torch.zeros(
            token_grid,
            dtype=bool,
            device="cpu",
            pin_memory=False,
        )
        self.is_token_ids = self.is_token_ids_tensor.numpy()

        # Prompt embeddings are kept per request rather than in one big
        # upfront allocation, to avoid OOM when max_model_len is large.
        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}

        self.num_tokens = np.zeros(per_req, dtype=np.int32)
        self.num_tokens_no_spec = np.zeros(per_req, dtype=np.int32)
        self.num_prompt_tokens = np.zeros(per_req, dtype=np.int32)
        self.num_computed_tokens_cpu_tensor = torch.zeros(
            per_req,
            dtype=torch.int32,
            device="cpu",
            pin_memory=pin_memory,
        )
        self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy()

        # ---- Block table ---------------------------------------------
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_sizes=block_sizes,
            max_num_blocks=max_num_blocks_per_req,
            num_speculative_tokens=num_speculative_tokens,
            kernel_sizes=kernel_block_sizes,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

        # ---- Sampling parameters -------------------------------------
        # Each parameter has a device tensor, a CPU tensor, a NumPy view of
        # the CPU tensor, and (where applicable) a set of request ids.
        self.temperature = on_device(torch.float32)
        self.temperature_cpu_tensor = on_host(torch.float32)
        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

        self.top_p = on_device(torch.float32)
        self.top_p_cpu_tensor = on_host(torch.float32)
        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
        self.top_p_reqs: set[str] = set()

        self.top_k = on_device(torch.int32)
        self.top_k_cpu_tensor = on_host(torch.int32)
        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
        self.top_k_reqs: set[str] = set()

        # IDs of requests which do not support spec decoding
        self.spec_decode_unsupported_reqs: set[str] = set()

        # Frequency penalty
        self.frequency_penalties = on_device(torch.float32)
        self.frequency_penalties_cpu_tensor = on_host(torch.float32)
        self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy()
        self.frequency_penalties_reqs: set[str] = set()

        # Presence penalty
        self.presence_penalties = on_device(torch.float32)
        self.presence_penalties_cpu_tensor = on_host(torch.float32)
        self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy()
        self.presence_penalties_reqs: set[str] = set()

        # Repetition penalty
        self.repetition_penalties = on_device(torch.float32)
        self.repetition_penalties_cpu_tensor = on_host(torch.float32)
        self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy()
        self.repetition_penalties_reqs: set[str] = set()

        # ---- Speculative decoding ------------------------------------
        # Initialised to one for every slot.
        self.num_accepted_tokens_cpu_tensor = torch.ones(
            per_req,
            dtype=torch.int64,
            device="cpu",
            pin_memory=pin_memory,
        )
        self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

        # ---- LoRA ----------------------------------------------------
        self.request_lora_mapping = np.zeros(per_req, dtype=np.int64)
        self.lora_id_to_request_ids: dict[int, set[str]] = {}
        self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

        # req_index -> generator
        # NOTE(woosuk): Requests without a generator of their own must not
        # appear in this dictionary.
        self.generators: dict[int, torch.Generator] = {}

        self.num_logprobs: dict[str, int] = {}

        # Accumulates prompt logprobs tensor chunks across prefill steps.
        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

        # Per-step record of batch state changes, used for reordering the
        # persistent batch and for generating logitsprocs batch state
        # updates. Should reset each step.
        self.batch_update_builder = BatchUpdateBuilder()

        # TODO convert this to LogitsProcessor
        self.has_allowed_token_ids: set[str] = set()
        # NOTE(lufang): In the mask tensor an allowed token is marked
        # False, because masked_fill_ is used to write -inf.
        self.allowed_token_ids_mask: torch.Tensor | None = None
        self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None

        # req_index -> bad_words_token_ids
        self.bad_words_token_ids: dict[int, list[list[int]]] = {}

        self.logits_processing_needs_token_ids = np.zeros(per_req, dtype=bool)

        self.req_output_token_ids: list[list[int] | None] = []

        # Use the supplied logitsprocs, or an empty container when none
        # are given.
        self.logitsprocs = logitsprocs or LogitsProcessors()
        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids

        # Last speculative tokens for the sampler; one independent list
        # per request slot.
        self.spec_token_ids: list[list[int]] = [[] for _slot in range(max_num_reqs)]

        # This is updated each time the batch constituents change.
        # Kept at this position: the inherited helper runs here and may
        # read the attributes assigned above.
        self.sampling_metadata = self._make_sampling_metadata()

        # ---- Pooling models ------------------------------------------
        self.pooling_params: dict[str, PoolingParams] = {}
        self.pooling_states: dict[str, PoolingStates] = {}

        # Cached reference to the device tensor of previously sampled tokens
        self.prev_sampled_token_ids: torch.Tensor | None = None
        self.prev_req_id_to_index: dict[str, int] | None = None
        # Used to update output_token_ids with the real sampled ids from
        # the prior step when the current sampling params need them
        # (e.g. penalties).
        self.sampled_token_ids_cpu: torch.Tensor | None = None
        self.async_copy_ready_event: torch.Event | None = None