[New model] Qwen3-next support (#2917)

### What this PR does / why we need it?
Add Qwen3-next support.

### Does this PR introduce _any_ user-facing change?
Yes, users can now run Qwen3-Next models.
Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916 — the
tutorial will be published
[here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html).

### How was this patch tested?
Doc CI passed

Related: https://github.com/vllm-project/vllm-ascend/issues/2884

Co-Authored-By: Angazenn <supperccell@163.com>
Co-Authored-By: zzzzwwjj <1183291235@qq.com>
Co-Authored-By: MengqingCao <cmq0113@163.com>
Co-Authored-By: linfeng-yuan <1102311262@qq.com>
Co-Authored-By: hust17yixuan <303660421@qq.com>
Co-Authored-By: SunnyLee219 <3294305115@qq.com>
Co-Authored-By: maoxx241 <maoxx241@umn.edu>


- vLLM version: v0.10.2
- vLLM main:
b834b4cbf1

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Your Name <you@example.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: hust17yixuan <303660421@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: hust17yixuan <303660421@qq.com>
This commit is contained in:
wangxiyuan
2025-09-16 01:17:42 +08:00
committed by GitHub
parent b5ccef6115
commit c556038ef0
26 changed files with 3959 additions and 258 deletions

View File

@@ -19,11 +19,14 @@
import copy
import gc
import math
import itertools
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import numpy as np
import numpy.typing as npt
@@ -33,10 +36,12 @@ import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
@@ -46,7 +51,8 @@ from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
is_global_first_rank)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
@@ -59,28 +65,32 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv, is_pin_memory_available)
LazyLoader, cdiv, get_dtype_size,
is_pin_memory_available)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -91,8 +101,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p,
@@ -241,14 +249,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
from vllm.v1.sample.sampler import Sampler
self.sampler = Sampler()
self.reorder_batch_threshold: Optional[int] = None
# Lazy initialization, these will be set after __init__
self.kv_caches: List[torch.Tensor] = []
self.attn_groups: list[list[AttentionGroup]] = []
self.encoder_cache: Dict[str, torch.Tensor] = {}
self.attn_mask = None
self.attn_state = None
self.requests: Dict[str, CachedRequestState] = {}
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.runner_only_attn_layers: set[str] = set()
ascend_config = get_ascend_config()
if ascend_config.ascend_scheduler_config.enabled:
@@ -279,8 +290,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
vllm_config, device)
self.attn_mask_builder = AttentionMaskBuilder(
self.model_config.max_model_len, self.dtype)
@@ -412,6 +422,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.npu.Stream() if \
self.use_async_scheduling else None
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
# https://github.com/vllm-project/vllm/pull/18298, due to some unknown
# reasons, we have to initialize the input batch before `load_model`,
# quantization + weight offloading will fail otherwise. As a temporary
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
kernel_block_sizes=None,
)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
def _make_buffer(self,
                 *size: Union[int, torch.SymInt],
                 dtype: torch.dtype,
                 numpy: bool = True) -> CpuGpuBuffer:
    """Create a `CpuGpuBuffer` on this runner's device.

    Args:
        *size: Dimensions forwarded positionally to `CpuGpuBuffer`.
        dtype: Element type of the buffer.
        numpy: Forwarded as `with_numpy`. Bfloat16 torch tensors cannot be
            directly cast to a numpy array, so pass False when a bfloat16
            buffer is needed without a corresponding numpy array.

    Returns:
        The buffer built with `self.device` and `self.pin_memory`.
    """
    buffer_options = {
        "dtype": dtype,
        "device": self.device,
        "pin_memory": self.pin_memory,
        "with_numpy": numpy,
    }
    return CpuGpuBuffer(*size, **buffer_options)
def _update_states_after_model_execute(
        self, output_token_ids: torch.Tensor) -> None:
    """Record how many tokens were accepted for each sequence.

    This is used for MTP/EAGLE with hybrid models: in linear attention only
    the last token's state is kept. With MTP/EAGLE, the states of the draft
    tokens are kept until we decide how many tokens are accepted for each
    sequence, and a shift is applied during the next iteration based on the
    number of accepted tokens.

    Does nothing unless the model is hybrid and speculative decoding is
    configured.

    Args:
        output_token_ids: 2-D tensor indexed as (row, position); presumably
            -1 marks a position that was not accepted -- TODO confirm.
    """
    if not (self.model_config.is_hybrid and self.speculative_config):
        return

    # Append a column of -1 so that every row contains at least one -1; the
    # index of the first -1 in a row is then its number of accepted tokens
    # (rows without any -1 yield the full row length).
    num_rows = output_token_ids.size(0)
    sentinel_column = torch.full((num_rows, 1),
                                 -1,
                                 device=output_token_ids.device)
    padded_ids = torch.cat([output_token_ids, sentinel_column], dim=1)
    first_sentinel_index = (padded_ids == -1).int().argmax(-1)
    accepted_counts = first_sentinel_index.cpu().numpy()

    for row, accepted in enumerate(accepted_counts):
        self.input_batch.num_accepted_tokens_cpu[row] = accepted
def _use_aclgraph(self) -> bool:
    """Tell whether ACL graph execution is enabled.

    Returns:
        True only when `cudagraph_mode` is not `CUDAGraphMode.NONE`, the
        compilation level is `CompilationLevel.PIECEWISE`, and the model
        config does not enforce eager execution.
    """
    compilation_config = self.compilation_config
    if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
        return False
    if compilation_config.level != CompilationLevel.PIECEWISE:
        return False
    return not self.model_config.enforce_eager
@@ -611,7 +688,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Condense the batched states if there are gaps left by removed requests
self.input_batch.condense()
# Allow attention backend to reorder the batch, potentially
self._may_reorder_batch(scheduler_output)
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
@@ -970,22 +1048,42 @@ class NPUModelRunner(LoRAModelRunnerMixin):
src=self.input_batch.prev_sampled_token_ids[
prev_common_req_indices_tensor, 0])
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
    """
    Reorder the requests in the batch according to the attention backend's
    needs. For example, some attention backends (namely MLA) may want to
    separate requests depending on whether the attention computation will
    be compute-bound or memory-bound.

    Nothing happens when there are no KV cache groups or when
    `self.reorder_batch_threshold` is None.

    Args:
        scheduler_output: The scheduler output.
    """
    # Attention-free models have zero kv_cache_groups. Models like Mamba
    # are also attention free, yet they use the kv_cache to keep their
    # internal state. Hence the number of kv_cache groups is checked
    # instead of relying solely on self.model_config.is_attention_free.
    if not self.kv_cache_config.kv_cache_groups:
        return

    threshold = self.reorder_batch_threshold
    if threshold is None:
        return

    reorder_batch_to_split_decodes_and_prefills(self.input_batch,
                                                scheduler_output,
                                                decode_threshold=threshold)
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata,
AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int,
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit_block_table(num_reqs)
@@ -1088,9 +1186,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
self.input_batch.block_table[0].
slot_mapping_cpu[:total_num_scheduled_tokens])
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
@@ -1131,32 +1226,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
block_table_tensor=self.input_batch.block_table[0].
get_device_tensor(),
slot_mapping_cpu=self.slot_mapping_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
decode_token_per_req=self.decode_token_per_req,
)
attn_metadata = self.attn_metadata_builder.build(
common_attn_metadata, self.model)
if self.vllm_config.model_config.use_mla:
attn_metadata.num_input_tokens = num_input_tokens
attn_metadata: dict[str, Any] = {}
# Prepare input_ids
token_indices = (positions_np +
@@ -1238,6 +1308,90 @@ class NPUModelRunner(LoRAModelRunnerMixin):
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
self.num_draft_tokens.copy_to_gpu()
# Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
if use_spec_decode:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()
slot_mapping = blk_table.slot_mapping_cpu[:
total_num_scheduled_tokens]
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
slot_mapping)
# # Fill unused with -1. Needed for reshape_and_cache in full cuda
# # graph mode.
# blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
actual_seq_lengths_q=self.actual_seq_lengths_q,
# TODO: change this to the right block table for linear attn
block_table_tensor=blk_table_tensor[:num_reqs],
slot_mapping_cpu=self.slot_mapping_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
positions=self.positions,
attn_mask=self.attn_mask,
spec_attn_mask=self.spec_attn_mask,
attn_state=self.attn_state,
enable_dbo_across_dp=enable_dbo,
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
max_query_len=max_num_scheduled_tokens,
graph_pad_size=self.graph_pad_size,
decode_token_per_req=self.decode_token_per_req,
)
if self.speculative_config and \
spec_decode_common_attn_metadata is None:
spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
common_prefix_len = 0
extra_attn_metadata_args = {}
builder = attn_group.metadata_builder
if isinstance(builder, GDNAttentionMetadataBuilder):
if use_spec_decode:
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.
gpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
**extra_attn_metadata_args)
else:
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
model=self.model,
**extra_attn_metadata_args)
if self.vllm_config.model_config.use_mla:
attn_metadata_i.num_input_tokens = num_input_tokens
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if lmhead_tp_enable():
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
@@ -1453,9 +1607,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions: torch.Tensor,
num_scheduled_tokens: int,
hidden_states: torch.Tensor,
attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata,
AscendMLATorchairMetadata],
attn_metadata: dict[str, Any],
aux_hidden_states: torch.Tensor = None,
) -> Optional[list[list[int]]]:
if not self.drafter:
@@ -1700,6 +1852,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices: list[int] = []
# TODO(woosuk): The following loop can be slow since it iterates over
@@ -2231,31 +2384,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
kv_caches: Dict[str, torch.Tensor] = {}
self.may_reinitialize_input_batch(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
if self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek(
kv_cache_config)
else:
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
def initialize_kv_cache_tensors_deepseek(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
@@ -2263,12 +2407,141 @@ class NPUModelRunner(LoRAModelRunnerMixin):
"NPU.")
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
):
attn_backend = kv_cache_group.backend
for layer_name in kv_cache_group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
tensor_size = kv_cache_sizes[layer_name]
assert tensor_size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
if self.vllm_config.additional_config.get(
"kv_cache_dtype", None) == 'int8':
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
elif hasattr(attn_backend, "get_supported_block_size"
) and not self.model_config.is_deepseek_mla:
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk, block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
alignment = 2 * 1024 * 1024
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
nope_dim)
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
rope_dim)
if self.vllm_config.kv_transfer_config is None:
# For no disaggregate pd scenario, allocate kv cache in normal way
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = self._convert_torch_format(rope_cache)
nope_cache = self._convert_torch_format(nope_cache)
else:
# To transfer the kv cache through the register_memory API of llmdatadist, the memory
# address must be aligned to 2M. In most cases torch_npu allocates 2M-aligned memory, but
# we found some exceptions during testing, so we align the memory manually here. This part
# of the code may consume an extra 2M * 2 * elem_size of memory for every layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment
nope_cache = torch.zeros(nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
rope_cache = torch.zeros(rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
nope_cache = align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(nope_cache_shape)
rope_cache = align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(rope_cache_shape)
kv_caches[layer_name] = (nope_cache, rope_cache)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# init kv cache tensors
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
# TODO: REFACTOR ME to sharing hybrid cache
for idx in range(len(kv_cache_tensor.shared_by)):
layer_name = kv_cache_tensor.shared_by[idx]
if "linear_attn" in layer_name:
for layer_name_inner in kv_cache_tensor.shared_by:
if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
):
continue
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=self.device)
kv_cache_raw_tensors[layer_name_inner] = tensor
elif "self_attn" in layer_name:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=self.device)
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
):
attn_backend = kv_cache_group.backend
for layer_name in kv_cache_group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel(
) // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
@@ -2278,100 +2551,228 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
alignment = 2 * 1024 * 1024
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, FullAttentionSpec):
if self.vllm_config.additional_config.get(
"kv_cache_dtype", None) == 'int8':
kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
elif hasattr(attn_backend, "get_supported_block_size"
) and not self.model_config.is_deepseek_mla:
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk, block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
if self.model_config.is_deepseek_mla:
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size,
num_kv_heads, nope_dim)
rope_cache_shape = (num_blocks, block_size,
num_kv_heads, rope_dim)
if self.vllm_config.kv_transfer_config is None:
# For no disaggregate pd scenario, allocate kv cache in normal way
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = self._convert_torch_format(rope_cache)
nope_cache = self._convert_torch_format(nope_cache)
else:
# To transfer the kv cache through the register_memory API of llmdatadist, the memory
# address must be aligned to 2M. In most cases torch_npu allocates 2M-aligned memory, but
# we found some exceptions during testing, so we align the memory manually here. This part
# of the code may consume an extra 2M * 2 * elem_size of memory for every layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment
nope_cache = torch.zeros(
nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
rope_cache = torch.zeros(
rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
nope_cache = align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(
nope_cache_shape)
rope_cache = align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(
rope_cache_shape)
kv_caches[layer_name] = (nope_cache, rope_cache)
else:
num_caches = kv_cache_shape[0]
kv_cache_list = []
for i in range(num_caches):
cache_shape = kv_cache_shape[1:]
if self.vllm_config.kv_transfer_config is None:
kv_cache = torch.zeros(cache_shape,
dtype=dtype,
device=self.device)
kv_cache = self._convert_torch_format(kv_cache)
else:
cache_size = math.prod(cache_shape)
cache_size_aligned = cache_size + alignment
kv_cache = torch.zeros(cache_size_aligned,
dtype=dtype,
device=self.device)
kv_cache = align_memory(
kv_cache,
alignment)[:cache_size].view(cache_shape)
kv_cache_list.append(kv_cache)
kv_caches[layer_name] = tuple(kv_cache_list)
kv_cache = raw_tensor.view(dtype).view(kv_cache_shape)
kv_cache = self._convert_torch_format(kv_cache)
kv_caches[layer_name] = kv_cache
elif isinstance(kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
state_tensors = []
storage_offset_bytes = 0
for (shape, dtype) in zip(kv_cache_spec.shapes,
kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
kv_cache_spec.page_size_bytes // dtype_size)
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
assert storage_offset_bytes % dtype_size == 0
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset_bytes // dtype_size,
)
state_tensors.append(tensor)
storage_offset_bytes += stride[0] * dtype_size
kv_caches[layer_name] = state_tensors
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
return kv_caches
def _kv_cache_spec_attn_group_iterator(
        self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
    """Yield `(kv_cache_spec, attn_group)` for every attention group.

    `self.attn_groups` is indexed by KV cache group id; each attention group
    is paired with the spec of the KV cache group at the same index. Yields
    nothing when the KV cache config has no groups.
    """
    kv_cache_groups = self.kv_cache_config.kv_cache_groups
    if not kv_cache_groups:
        return

    for group_id, groups_for_spec in enumerate(self.attn_groups):
        for attn_group in groups_for_spec:
            spec = kv_cache_groups[group_id].kv_cache_spec
            yield spec, attn_group
def may_reinitialize_input_batch(self,
                                 kv_cache_config: KVCacheConfig) -> None:
    """
    Rebuild `self.input_batch` from the KV cache configuration when the
    computed kernel block sizes differ from `[self.cache_config.block_size]`.
    This usually happens when there are multiple KV cache groups.

    Args:
        kv_cache_config: The KV cache configuration.
    """
    cache_groups = kv_cache_config.kv_cache_groups
    block_sizes = [group.kv_cache_spec.block_size for group in cache_groups]

    # One kernel-block-size list per KV cache group, aligned with
    # `block_sizes`. Attention groups take the sizes supported by their
    # backend (virtual block splitting); anything else gets [0], meaning no
    # splitting.
    kernel_block_sizes = []
    for group_id, group in enumerate(cache_groups):
        if not isinstance(group.kv_cache_spec, AttentionSpec):
            # Likely Mamba or another non-attention cache: no splitting.
            kernel_block_sizes.append([0])
            continue

        try:
            group_attn_groups = self.attn_groups[group_id]
        except IndexError:
            group_attn_groups = None

        if group_attn_groups:
            # Use the supported block size list of the group's first
            # backend, or the cache config block_size when the backend
            # reports none.
            supported_sizes = (
                group_attn_groups[0].backend.get_supported_block_size())
            kernel_block_size_list = (supported_sizes or
                                      [self.cache_config.block_size])
        elif self.model_config.is_deepseek_mla:
            # No attention group found for this KV cache group: DeepSeek
            # MLA uses [0] (no splitting).
            kernel_block_size_list = [0]
        else:
            # No attention group found for this KV cache group.
            # NOTE(review): the fallback is the literal 64 rather than the
            # cache config block_size -- confirm this is intended.
            kernel_block_size_list = [64]
        kernel_block_sizes.append(kernel_block_size_list)

    # NOTE(review): `kernel_block_sizes` is a list of lists while the other
    # operand is a list of ints, so this comparison appears to never be
    # equal and the batch is always rebuilt -- confirm this is intended.
    if kernel_block_sizes == [self.cache_config.block_size]:
        return

    assert self.cache_config.cpu_offload_gb == 0, (
        "Cannot re-initialize the input batch when CPU weight "
        "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
        "for more details.")

    speculative_config = self.vllm_config.speculative_config
    if speculative_config:
        num_speculative_tokens = speculative_config.num_speculative_tokens
    else:
        num_speculative_tokens = 0

    # The logits processors of the previous input batch are reused.
    self.input_batch = InputBatch(
        max_num_reqs=self.max_num_reqs,
        max_model_len=self.model_config.max_model_len,
        max_num_batched_tokens=self.max_num_tokens,
        device=self.device,
        pin_memory=self.pin_memory,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=block_sizes,
        is_spec_decode=bool(speculative_config),
        logitsprocs=self.input_batch.logitsprocs,
        is_pooling_model=self.is_pooling_model,
        num_speculative_tokens=num_speculative_tokens,
        kernel_block_sizes=kernel_block_sizes,
    )
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
    """
    Build the attention groups for every KV cache group.

    For each KV cache group in ``kv_cache_config.kv_cache_groups`` the
    group's layers are bucketed by the attention backend they report, one
    metadata builder is instantiated per distinct backend, and the
    resulting list of ``AttentionGroup`` objects is appended to
    ``self.attn_groups`` (one inner list per KV cache group, in order).
    Finally the reorder-batch threshold is recomputed.

    Must only be called while ``self.attn_groups`` is still empty.
    """
    assert len(self.attn_groups) == 0, \
        "Attention backends are already initialized"

    def group_layers_by_backend(
        layer_names: list[str]
    ) -> dict[type[AttentionBackend], list[str]]:
        """Map each distinct attention backend to the layers using it."""
        resolved_layers = get_layers_from_vllm_config(
            self.vllm_config, AttentionLayerBase, layer_names)
        # Backends are deduplicated by their full class name rather than
        # by the class object: dynamically created backend subclasses
        # (e.g. ChunkedLocalAttention) may be distinct objects per layer
        # unless they are cached correctly.
        backend_by_cls_name = {}
        layers_by_cls_name = {}
        for name in layer_names:
            backend = resolved_layers[name].get_attn_backend()
            cls_name = backend.full_cls_name()
            backend_by_cls_name[cls_name] = backend
            layers_by_cls_name.setdefault(cls_name, []).append(name)
        return {
            backend_by_cls_name[cls_name]: names
            for cls_name, names in layers_by_cls_name.items()
        }

    def build_groups(
        backend_to_layers: dict[type[AttentionBackend], list[str]],
        spec: KVCacheSpec,
    ) -> list[AttentionGroup]:
        """Create one ``AttentionGroup`` (with its builder) per backend."""
        return [
            AttentionGroup(
                backend,
                backend.get_builder_cls()(
                    spec,
                    names,
                    self.vllm_config,
                    self.device,
                ),
                names,
            ) for backend, names in backend_to_layers.items()
        ]

    for group_spec in kv_cache_config.kv_cache_groups:
        spec = group_spec.kv_cache_spec
        backend_to_layers = group_layers_by_backend(group_spec.layer_names)
        self.attn_groups.append(build_groups(backend_to_layers, spec))

    # Calculate reorder batch threshold (if needed)
    self.calculate_reorder_batch_threshold()
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
    """Return a lazy iterator over every group in ``self.attn_groups``.

    ``self.attn_groups`` is an iterable of iterables of groups; the result
    yields the inner elements one after another, flattened.
    """
    groups_per_kv_cache_group = self.attn_groups
    return itertools.chain.from_iterable(groups_per_kv_cache_group)
def calculate_reorder_batch_threshold(self) -> None:
    """
    Reconcile the batch-reorder threshold across all attention groups.

    Walks every group yielded by ``self._attn_group_iterator()`` and reads
    ``reorder_batch_threshold`` from the group's metadata builder. Builders
    that do not define the attribute, or define it as ``None``, are
    ignored. If ``self.reorder_batch_threshold`` is still ``None`` the
    first non-``None`` value found is stored there; every other
    non-``None`` value must be equal to the stored one.

    Raises:
        ValueError: If a builder's threshold differs from the value held
            in ``self.reorder_batch_threshold``.
    """
    for group in self._attn_group_iterator():
        # A builder without the attribute is handled exactly like one
        # that sets it to None: it places no constraint on reordering.
        threshold = getattr(group.metadata_builder,
                            "reorder_batch_threshold", None)
        if threshold is None:
            continue

        if self.reorder_batch_threshold is None:
            # First backend that reorders batches fixes the threshold.
            self.reorder_batch_threshold = threshold
            continue

        # All backends that reorder batches must agree on the threshold.
        if threshold != self.reorder_batch_threshold:
            raise ValueError(
                f"Attention backend reorders decodes with "
                f"threshold {threshold} but other "
                f"backend uses threshold "
                f"{self.reorder_batch_threshold}")
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
@@ -2382,19 +2783,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
forward_ctx = self.compilation_config.static_forward_context
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention
assert isinstance(attn_module, Attention)
# TODO: Support other attention modules, e.g., cross-attention
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=self.block_size,
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
@@ -2409,6 +2820,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
if (self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
page_size_padded = (
self.vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtypes=mamba_module.get_state_dtype(),
block_size=max_model_len,
page_size_padded=page_size_padded,
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config else 0),
)
return kv_cache_spec
def initialize_aclgraph_capture(self) -> None: