[New model] Qwen3-next support (#2917)
### What this PR does / why we need it?
Add Qwen3-next support.
### Does this PR introduce _any_ user-facing change?
Yes, users can now run Qwen3-Next.
Related doc: https://github.com/vllm-project/vllm-ascend/pull/2916; the
tutorial will be available
[here](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_npu_qwen3_next.html).
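A minimal offline-inference sketch of what this enables (the checkpoint name, parallelism degree, and context length below are illustrative assumptions, not part of this PR; see the tutorial above for the authoritative setup):

```python
# Illustrative sketch only: model name and sizes are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-Next-80B-A3B-Instruct",  # assumed checkpoint name
    tensor_parallel_size=4,  # Qwen3-Next is large; multiple NPUs expected
    max_model_len=4096,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```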
### How was this patch tested?
Doc CI passed
Related: https://github.com/vllm-project/vllm-ascend/issues/2884
Co-Authored-By: Angazenn <supperccell@163.com>
Co-Authored-By: zzzzwwjj <1183291235@qq.com>
Co-Authored-By: MengqingCao <cmq0113@163.com>
Co-Authored-By: linfeng-yuan <1102311262@qq.com>
Co-Authored-By: hust17yixuan <303660421@qq.com>
Co-Authored-By: SunnyLee219 <3294305115@qq.com>
Co-Authored-By: maoxx241 <maoxx241@umn.edu>
- vLLM version: v0.10.2
- vLLM main: b834b4cbf1
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Your Name <you@example.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: hust17yixuan <303660421@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: hust17yixuan <303660421@qq.com>
@@ -19,11 +19,14 @@

import copy
import gc
import math
import itertools
import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

import numpy as np
import numpy.typing as npt
@@ -33,10 +36,12 @@ import torch.distributed as dist
import torch.nn as nn
from tqdm import tqdm  # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
@@ -46,7 +51,8 @@ from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                             is_global_first_rank)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
@@ -59,28 +65,32 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LazyLoader, cdiv, is_pin_memory_available)
                        LazyLoader, cdiv, get_dtype_size,
                        is_pin_memory_available)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
    reorder_batch_to_split_decodes_and_prefills
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec, MambaSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                  gather_mm_placeholders,
                                  sanity_check_mm_encoder_outputs,
                                  scatter_mm_placeholders)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
                                                AscendMetadata)
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -91,8 +101,6 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               AscendSocVersion, ProfileExecuteDuration,
                               get_ascend_soc_version, is_310p,
@@ -241,14 +249,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        from vllm.v1.sample.sampler import Sampler

        self.sampler = Sampler()
        self.reorder_batch_threshold: Optional[int] = None

        # Lazy initialization, these will be set after __init__
        self.kv_caches: List[torch.Tensor] = []
        self.attn_groups: list[list[AttentionGroup]] = []
        self.encoder_cache: Dict[str, torch.Tensor] = {}
        self.attn_mask = None
        self.attn_state = None
        self.requests: Dict[str, CachedRequestState] = {}
        self.intermediate_tensors: Optional[IntermediateTensors] = None
        self.runner_only_attn_layers: set[str] = set()

        ascend_config = get_ascend_config()
        if ascend_config.ascend_scheduler_config.enabled:
@@ -279,8 +290,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            vllm_config, device)

        self.attn_mask_builder = AttentionMaskBuilder(
            self.model_config.max_model_len, self.dtype)

@@ -412,6 +422,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.async_output_copy_stream = torch.npu.Stream() if \
            self.use_async_scheduling else None
        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
            kernel_block_sizes=None,
        )
        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                     dtype=torch.int64)
        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                  dtype=torch.int32)

    def _make_buffer(self,
                     *size: Union[int, torch.SymInt],
                     dtype: torch.dtype,
                     numpy: bool = True) -> CpuGpuBuffer:
        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
        # if a bfloat16 buffer is needed without a corresponding numpy array,
        # don't bother instantiating the numpy array.
        return CpuGpuBuffer(*size,
                            dtype=dtype,
                            device=self.device,
                            pin_memory=self.pin_memory,
                            with_numpy=numpy)

    def _update_states_after_model_execute(
            self, output_token_ids: torch.Tensor) -> None:
        """Update the cached states after model execution.

        This is used for MTP/EAGLE with hybrid models: in linear attention,
        only the last token's state is kept. In MTP/EAGLE, the states for
        draft tokens are kept until we decide how many tokens are accepted
        for each sequence, and a shift is done during the next iteration
        based on the number of accepted tokens.
        """
        if not self.model_config.is_hybrid or not self.speculative_config:
            return

        # Find the number of accepted tokens for each sequence.
        num_accepted_tokens = (torch.cat(
            [
                output_token_ids,
                torch.full((output_token_ids.size(0), 1),
                           -1,
                           device=output_token_ids.device),
            ],
            dim=1) == -1).int().argmax(-1).cpu().numpy()
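        # Editor's note (annotation, not part of the patch): appending a
        # column of -1 guarantees every row contains at least one -1, so
        # argmax over (tokens == -1) returns the index of the first
        # rejected/padded slot, i.e. the number of accepted tokens per row.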
        for i, num_tokens in enumerate(num_accepted_tokens):
            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

    def _use_aclgraph(self) -> bool:
        return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
@@ -611,7 +688,8 @@

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()

        # Allow attention backend to reorder the batch, potentially
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

@@ -970,22 +1048,42 @@
                    src=self.input_batch.prev_sampled_token_ids[
                        prev_common_req_indices_tensor, 0])

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
        # Attention-free models have zero kv_cache_groups, but models
        # like Mamba are also attention-free and use the kv_cache for
        # keeping their internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        if self.reorder_batch_threshold is not None:
            reorder_batch_to_split_decodes_and_prefills(
                self.input_batch,
                scheduler_output,
                decode_threshold=self.reorder_batch_threshold)

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> tuple[Union[AscendMetadata, AscendMLAMetadata, AscendTorchairMetadata,
               AscendMLATorchairMetadata], torch.Tensor, np.ndarray, int,
               torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
               Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
    ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
               int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        self.attn_metadata_builder.reorder_batch(self.input_batch,
                                                 scheduler_output)
        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit_block_table(num_reqs)
@@ -1088,9 +1186,6 @@
                                                 req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(
            total_num_scheduled_tokens)
        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
            self.input_batch.block_table[0].
            slot_mapping_cpu[:total_num_scheduled_tokens])

        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
@@ -1131,32 +1226,7 @@
        self.with_prefill = with_prefill
        self.num_tokens_across_dp = num_tokens_across_dp
        self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)

        # Make AscendCommonAttentionMetadata
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[:num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
            seq_lens_cpu=self.seq_lens_cpu,
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            actual_seq_lengths_q=self.actual_seq_lengths_q,
            block_table_tensor=self.input_batch.block_table[0].
            get_device_tensor(),
            slot_mapping_cpu=self.slot_mapping_cpu,
            positions=self.positions,
            attn_mask=self.attn_mask,
            spec_attn_mask=self.spec_attn_mask,
            attn_state=self.attn_state,
            enable_dbo_across_dp=enable_dbo,
            is_only_prefill=bool(np.all(num_valid_tokens != 1)),
            max_query_len=max_num_scheduled_tokens,
            graph_pad_size=self.graph_pad_size,
            decode_token_per_req=self.decode_token_per_req,
        )
        attn_metadata = self.attn_metadata_builder.build(
            common_attn_metadata, self.model)
        if self.vllm_config.model_config.use_mla:
            attn_metadata.num_input_tokens = num_input_tokens
        attn_metadata: dict[str, Any] = {}

        # Prepare input_ids
        token_indices = (positions_np +
@@ -1238,6 +1308,90 @@
            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices
            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
            self.num_draft_tokens.np[num_reqs:].fill(0)
            self.num_draft_tokens.copy_to_gpu()

        # Used in the below loop.
        # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
        num_computed_tokens_cpu = (
            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
        spec_decode_common_attn_metadata = None
        if use_spec_decode:
            self.num_accepted_tokens.np[:num_reqs] = (
                self.input_batch.num_accepted_tokens_cpu[:num_reqs])
            self.num_accepted_tokens.np[num_reqs:].fill(1)
            self.num_accepted_tokens.copy_to_gpu()

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            blk_table = self.input_batch.block_table[kv_cache_group_id]
            blk_table_tensor = blk_table.get_device_tensor()
            slot_mapping = blk_table.slot_mapping_cpu[:
                                                      total_num_scheduled_tokens]
            self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
                slot_mapping)
            # # Fill unused with -1. Needed for reshape_and_cache in full cuda
            # # graph mode.
            # blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)

            # Make AscendCommonAttentionMetadata
            common_attn_metadata = AscendCommonAttentionMetadata(
                query_start_loc=self.query_start_loc[:num_reqs + 1],
                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
                seq_lens_cpu=self.seq_lens_cpu,
                seq_lens=self.seq_lens_cpu[:num_reqs],
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                actual_seq_lengths_q=self.actual_seq_lengths_q,
                # TODO: change this to the right block table for linear attn
                block_table_tensor=blk_table_tensor[:num_reqs],
                slot_mapping_cpu=self.slot_mapping_cpu,
                num_computed_tokens_cpu=num_computed_tokens_cpu,
                positions=self.positions,
                attn_mask=self.attn_mask,
                spec_attn_mask=self.spec_attn_mask,
                attn_state=self.attn_state,
                enable_dbo_across_dp=enable_dbo,
                is_only_prefill=bool(np.all(num_valid_tokens != 1)),
                max_query_len=max_num_scheduled_tokens,
                graph_pad_size=self.graph_pad_size,
                decode_token_per_req=self.decode_token_per_req,
            )

            if self.speculative_config and \
                    spec_decode_common_attn_metadata is None:
                spec_decode_common_attn_metadata = common_attn_metadata

            for attn_group in self.attn_groups[kv_cache_group_id]:
                common_prefix_len = 0
                extra_attn_metadata_args = {}
                builder = attn_group.metadata_builder
                if isinstance(builder, GDNAttentionMetadataBuilder):
                    if use_spec_decode:
                        extra_attn_metadata_args = dict(
                            num_accepted_tokens=self.num_accepted_tokens.
                            gpu[:num_reqs],
                            num_draft_tokens=self.num_draft_tokens.
                            gpu[:num_reqs],
                        )
                    attn_metadata_i = builder.build(
                        common_prefix_len=common_prefix_len,
                        common_attn_metadata=common_attn_metadata,
                        **extra_attn_metadata_args)
                else:
                    attn_metadata_i = builder.build(
                        common_prefix_len=common_prefix_len,
                        common_attn_metadata=common_attn_metadata,
                        model=self.model,
                        **extra_attn_metadata_args)

                if self.vllm_config.model_config.use_mla:
                    attn_metadata_i.num_input_tokens = num_input_tokens
                for layer_name in attn_group.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

        if lmhead_tp_enable():
            max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
@@ -1453,9 +1607,7 @@
        positions: torch.Tensor,
        num_scheduled_tokens: int,
        hidden_states: torch.Tensor,
        attn_metadata: Union[AscendMetadata, AscendMLAMetadata,
                             AscendTorchairMetadata,
                             AscendMLATorchairMetadata],
        attn_metadata: dict[str, Any],
        aux_hidden_states: torch.Tensor = None,
    ) -> Optional[list[list[int]]]:
        if not self.drafter:
@@ -1700,6 +1852,7 @@
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids
        self._update_states_after_model_execute(output_token_ids)

        discard_sampled_tokens_req_indices: list[int] = []
        # TODO(woosuk): The following loop can be slow since it iterates over
@@ -2231,31 +2384,22 @@
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        kv_caches: Dict[str, torch.Tensor] = {}
        self.may_reinitialize_input_batch(kv_cache_config)
        self.initialize_attn_backend(kv_cache_config)

        def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
            data_ptr = tensor.data_ptr()
            aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
            offset = (aligned_addr - data_ptr) // tensor.element_size()
            return tensor[int(offset):]
        if self.model_config.is_deepseek_mla:
            kv_caches = self.initialize_kv_cache_tensors_deepseek(
                kv_cache_config)
        else:
            kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.model_config.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
        )
        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)

    def initialize_kv_cache_tensors_deepseek(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        kv_cache_sizes = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            assert len(kv_cache_tensor.shared_by) == 1, (
@@ -2263,12 +2407,141 @@
                "NPU.")
            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
        def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
            data_ptr = tensor.data_ptr()
            aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
            offset = (aligned_addr - data_ptr) // tensor.element_size()
            return tensor[int(offset):]

        kv_caches: Dict[str, torch.Tensor] = {}
        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
        ):
            attn_backend = kv_cache_group.backend
            for layer_name in kv_cache_group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                tensor_size = kv_cache_sizes[layer_name]
                assert tensor_size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
                if self.vllm_config.additional_config.get(
                        "kv_cache_dtype", None) == 'int8':
                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                elif hasattr(attn_backend, "get_supported_block_size"
                             ) and not self.model_config.is_deepseek_mla:
                    block_size = attn_backend.get_supported_block_size()[0]
                    block_size_chunk = kv_cache_spec.block_size // block_size
                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        num_blocks * block_size_chunk, block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                else:
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                dtype = kv_cache_spec.dtype

                alignment = 2 * 1024 * 1024
                num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                nope_dim = head_size - rope_dim
                nope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                    nope_dim)
                rope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                    rope_dim)
                if self.vllm_config.kv_transfer_config is None:
                    # For the non-disaggregated PD scenario, allocate the kv cache in the normal way
                    rope_cache = torch.zeros(rope_cache_shape,
                                             dtype=dtype,
                                             device=self.device)
                    nope_cache = torch.zeros(nope_cache_shape,
                                             dtype=dtype,
                                             device=self.device)
                    rope_cache = self._convert_torch_format(rope_cache)
                    nope_cache = self._convert_torch_format(nope_cache)
                else:

                    # In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
                    # address should be aligned to 2M. In most cases torch_npu can allocate 2M-aligned memory, but
                    # we found some exceptions during testing, so we manually align the memory here; this part
                    # of code may consume 2M * 2 * elem_size of memory per layer.
                    nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                    nope_allocate_shape_alignment = nope_allocate_shape + alignment
                    rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                    rope_allocate_shape_alignment = rope_allocate_shape + alignment

                    nope_cache = torch.zeros(nope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                             dtype=dtype,
                                             device=self.device)
                    nope_cache = align_memory(
                        nope_cache,
                        alignment)[:nope_allocate_shape].view(nope_cache_shape)
                    rope_cache = align_memory(
                        rope_cache,
                        alignment)[:rope_allocate_shape].view(rope_cache_shape)
                kv_caches[layer_name] = (nope_cache, rope_cache)

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        return kv_caches

    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Initialize the memory buffer for KV cache.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map from layer name to its
                corresponding memory buffer for KV cache.
        """
        # init kv cache tensors
        kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
            # TODO: REFACTOR ME to sharing hybrid cache
            for idx in range(len(kv_cache_tensor.shared_by)):
                layer_name = kv_cache_tensor.shared_by[idx]
                if "linear_attn" in layer_name:
                    for layer_name_inner in kv_cache_tensor.shared_by:
                        if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
                        ):
                            continue
                        tensor = torch.zeros(kv_cache_tensor.size,
                                             dtype=torch.int8,
                                             device=self.device)
                        kv_cache_raw_tensors[layer_name_inner] = tensor
                elif "self_attn" in layer_name:
                    tensor = torch.zeros(kv_cache_tensor.size,
                                         dtype=torch.int8,
                                         device=self.device)
                    kv_cache_raw_tensors[layer_name] = tensor

        layer_names = set()
        for group in kv_cache_config.kv_cache_groups:
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                layer_names.add(layer_name)
        assert layer_names == set(kv_cache_raw_tensors.keys(
        )), "Some layers are not correctly initialized"

        kv_caches: Dict[str, torch.Tensor] = {}
        for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
        ):
            attn_backend = kv_cache_group.backend
            for layer_name in kv_cache_group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                raw_tensor = kv_cache_raw_tensors[layer_name]

                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                num_blocks = raw_tensor.numel(
                ) // kv_cache_spec.page_size_bytes

                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
@@ -2278,100 +2551,228 @@
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                alignment = 2 * 1024 * 1024

                # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
                # encounter OOM issue
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    if self.vllm_config.additional_config.get(
                            "kv_cache_dtype", None) == 'int8':
                        kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
                        kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    elif hasattr(attn_backend, "get_supported_block_size"
                                 ) and not self.model_config.is_deepseek_mla:
                        block_size = attn_backend.get_supported_block_size()[0]
                        block_size_chunk = kv_cache_spec.block_size // block_size
                        kv_cache_shape = attn_backend.get_kv_cache_shape(
                            num_blocks * block_size_chunk, block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    else:
                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    if self.model_config.is_deepseek_mla:
                        num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                        rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                        nope_dim = head_size - rope_dim
                        nope_cache_shape = (num_blocks, block_size,
                                            num_kv_heads, nope_dim)
                        rope_cache_shape = (num_blocks, block_size,
                                            num_kv_heads, rope_dim)
                        if self.vllm_config.kv_transfer_config is None:
                            # For the non-disaggregated PD scenario, allocate the kv cache in the normal way
                            rope_cache = torch.zeros(rope_cache_shape,
                                                     dtype=dtype,
                                                     device=self.device)
                            nope_cache = torch.zeros(nope_cache_shape,
                                                     dtype=dtype,
                                                     device=self.device)
                            rope_cache = self._convert_torch_format(rope_cache)
                            nope_cache = self._convert_torch_format(nope_cache)
                        else:

                            # In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
                            # address should be aligned to 2M. In most cases torch_npu can allocate 2M-aligned memory, but
                            # we found some exceptions during testing, so we manually align the memory here; this part
                            # of code may consume 2M * 2 * elem_size of memory per layer.
                            nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                            nope_allocate_shape_alignment = nope_allocate_shape + alignment
                            rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                            rope_allocate_shape_alignment = rope_allocate_shape + alignment

                            nope_cache = torch.zeros(
                                nope_allocate_shape_alignment,
                                dtype=dtype,
                                device=self.device)
                            rope_cache = torch.zeros(
                                rope_allocate_shape_alignment,
                                dtype=dtype,
                                device=self.device)
                            nope_cache = align_memory(
                                nope_cache,
                                alignment)[:nope_allocate_shape].view(
                                    nope_cache_shape)
                            rope_cache = align_memory(
                                rope_cache,
                                alignment)[:rope_allocate_shape].view(
                                    rope_cache_shape)
                        kv_caches[layer_name] = (nope_cache, rope_cache)
                    else:
                        num_caches = kv_cache_shape[0]
                        kv_cache_list = []
                        for i in range(num_caches):
                            cache_shape = kv_cache_shape[1:]
                            if self.vllm_config.kv_transfer_config is None:
                                kv_cache = torch.zeros(cache_shape,
                                                       dtype=dtype,
                                                       device=self.device)
                                kv_cache = self._convert_torch_format(kv_cache)
                            else:
                                cache_size = math.prod(cache_shape)
                                cache_size_aligned = cache_size + alignment
                                kv_cache = torch.zeros(cache_size_aligned,
                                                       dtype=dtype,
                                                       device=self.device)
                                kv_cache = align_memory(
                                    kv_cache,
                                    alignment)[:cache_size].view(cache_shape)
                            kv_cache_list.append(kv_cache)
                        kv_caches[layer_name] = tuple(kv_cache_list)
                    kv_cache = raw_tensor.view(dtype).view(kv_cache_shape)
                    kv_cache = self._convert_torch_format(kv_cache)
                    kv_caches[layer_name] = kv_cache
                elif isinstance(kv_cache_spec, MambaSpec):
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    state_tensors = []
                    storage_offset_bytes = 0
                    for (shape, dtype) in zip(kv_cache_spec.shapes,
                                              kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size)
                        target_shape = (num_blocks, *shape)
                        stride = torch.empty(target_shape).stride()
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        storage_offset_bytes += stride[0] * dtype_size
                    kv_caches[layer_name] = state_tensors
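                    # Editor's note (annotation, not part of the patch): each
                    # state tensor is carved out of the same raw byte buffer.
                    # The per-block stride is page_size_bytes // dtype_size, so
                    # all states of one cache block stay within one page, and
                    # storage_offset_bytes advances past the previous state's
                    # footprint inside that page.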
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)

        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)
        return kv_caches

    def _kv_cache_spec_attn_group_iterator(
            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
        if not self.kv_cache_config.kv_cache_groups:
            return
        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
            for attn_group in attn_groups:
                yield self.kv_cache_config.kv_cache_groups[
                    kv_cache_spec_id].kv_cache_spec, attn_group

    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Re-initialize the input batch if the block sizes are different from
        `[self.cache_config.block_size]`. This usually happens when there
        are multiple KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
        """
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        # Generate kernel_block_sizes that match each block_size.
        # For attention backends that support virtual block splitting,
        # use the supported block sizes from the backend.
        # For other backends (like Mamba), use [0] (no splitting).
        kernel_block_sizes = []
        for kv_cache_group_id, kv_cache_group in enumerate(
                kv_cache_config.kv_cache_groups):
            if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
                # This is an attention backend that supports virtual
                # block splitting. Get the supported block sizes from
                # the backend.
                try:
                    attn_groups = self.attn_groups[kv_cache_group_id]
                except IndexError:
                    attn_groups = None
                if attn_groups:
                    # Use the backend's supported block size list
                    backend = attn_groups[0].backend
                    supported_sizes = backend.get_supported_block_size()
                    # If no specific sizes are supported, use the cache config
                    # block_size
                    kernel_block_size_list = (supported_sizes
                                              if supported_sizes else
                                              [self.cache_config.block_size])
                else:
                    # Fallback to cache config block_size if no backend found
                    kernel_block_size_list = [
                        64
                    ] if not self.model_config.is_deepseek_mla else [0]
                kernel_block_sizes.append(kernel_block_size_list)
            else:
                # This is likely Mamba or another non-attention cache;
                # no splitting.
                kernel_block_sizes.append([0])
        if kernel_block_sizes != [self.cache_config.block_size]:
            assert self.cache_config.cpu_offload_gb == 0, (
                "Cannot re-initialize the input batch when CPU weight "
                "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
                "for more details.")
            self.input_batch = InputBatch(
                max_num_reqs=self.max_num_reqs,
                max_model_len=self.model_config.max_model_len,
                max_num_batched_tokens=self.max_num_tokens,
                device=self.device,
                pin_memory=self.pin_memory,
                vocab_size=self.model_config.get_vocab_size(),
                block_sizes=block_sizes,
                is_spec_decode=bool(self.vllm_config.speculative_config),
                logitsprocs=self.input_batch.logitsprocs,
                is_pooling_model=self.is_pooling_model,
                num_speculative_tokens=(
                    self.vllm_config.speculative_config.num_speculative_tokens
                    if self.vllm_config.speculative_config else 0),
                kernel_block_sizes=kernel_block_sizes,
            )

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_groups) == 0, \
            "Attention backends are already initialized"

        def get_attn_backends_for_layers(
                layer_names: list[str]
        ) -> dict[type[AttentionBackend], list[str]]:
            layers = get_layers_from_vllm_config(self.vllm_config,
                                                 AttentionLayerBase,
                                                 layer_names)
            attn_backends = {}
            attn_backend_layers = defaultdict(list)
            # Dedupe based on full class name; this is a bit safer than
            # using the class itself as the key because when we create dynamic
            # attention backend subclasses (e.g. ChunkedLocalAttention), unless
            # they are cached correctly, there will be different objects per
            # layer.
            for layer_name in layer_names:
                attn_backend = layers[layer_name].get_attn_backend()
                key = attn_backend.full_cls_name()
                attn_backends[key] = attn_backend
                attn_backend_layers[key].append(layer_name)
            return {
                attn_backends[k]: v
                for k, v in attn_backend_layers.items()
            }

        def create_attn_groups(
            attn_backends_map: dict[AttentionBackend, list[str]],
            kv_cache_spec: KVCacheSpec,
        ) -> list[AttentionGroup]:
            attn_groups: list[AttentionGroup] = []
            for attn_backend, layer_names in attn_backends_map.items():
                attn_metadata_builder_i = attn_backend.get_builder_cls()(
                    kv_cache_spec,
                    layer_names,
                    self.vllm_config,
                    self.device,
                )
                attn_group = AttentionGroup(attn_backend,
                                            attn_metadata_builder_i,
                                            layer_names)
                attn_groups.append(attn_group)
            return attn_groups

        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            attn_backends = get_attn_backends_for_layers(
                kv_cache_group_spec.layer_names)
            self.attn_groups.append(
                create_attn_groups(attn_backends, kv_cache_spec))

        # Calculate reorder batch threshold (if needed)
        self.calculate_reorder_batch_threshold()

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        return itertools.chain.from_iterable(self.attn_groups)

    def calculate_reorder_batch_threshold(self) -> None:
        """
        Check that, if any backends reorder batches, the reordering
        is compatible (e.g., the decode threshold is the same).
        """
        for group in self._attn_group_iterator():
            attn_metadata_builder_i = group.metadata_builder
            if hasattr(attn_metadata_builder_i, "reorder_batch_threshold"):
                # Check that, if any backends reorder batches, the reordering
                # is compatible (e.g., the decode threshold is the same).
                reorder_batch_threshold_i = (
                    attn_metadata_builder_i.reorder_batch_threshold)
                if reorder_batch_threshold_i is not None:
                    if self.reorder_batch_threshold is not None:
                        if reorder_batch_threshold_i != \
                                self.reorder_batch_threshold:
                            raise ValueError(
                                f"Attention backend reorders decodes with "
                                f"threshold {reorder_batch_threshold_i} but other "
                                f"backend uses threshold "
                                f"{self.reorder_batch_threshold}")
                    else:
                        self.reorder_batch_threshold = reorder_batch_threshold_i

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
@@ -2382,19 +2783,29 @@
        format. Layers that do not need KV cache are not included.
        """

        forward_ctx = self.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            if isinstance(attn_module, FusedMoE):
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention
            assert isinstance(attn_module, Attention)
            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=self.block_size,
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
@@ -2409,6 +2820,35 @@
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if (self.vllm_config.speculative_config is not None
                    and self.vllm_config.model_config.hf_config.model_type
                    not in ["qwen3_next"]):
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that the mamba model will
            # always have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config else 0),
                )

        return kv_cache_spec

    def initialize_aclgraph_capture(self) -> None: