[1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vllm main, and this PR is aligned with 17c540a993.
1. Refactor deepseek to the latest code arch as of 17c540a993.
2. Bunches of fixes due to vllm changes:
- Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init getting an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import, caused by https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` import, caused by https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` import, caused by https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode`, introduced by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of `get_input_positions_tensor` changing to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807
- Fix structured output break introduced by https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
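Nearly every compatibility fix listed above follows the same pattern, visible throughout the diff below: gate on `vllm_version_is("0.11.0")` so that the released v0.11.0 API stays on one branch and the renamed or relocated symbols from vllm main go on the other. A minimal sketch of that shim is shown here, assuming only the names that appear in this diff (`vllm_version_is` from `vllm_ascend.utils`, `CompilationLevel`/`CompilationMode`, and the two `LazyLoader` locations); the helper function and its parameters are illustrative, not part of the patch:

```python
# Compatibility-shim sketch: pick the right vLLM symbols for the installed version.
# All imported names appear in this PR's diff; the helper below is hypothetical.
from vllm_ascend.utils import vllm_version_is

if vllm_version_is("0.11.0"):
    # Released v0.11.0 keeps the old names and locations.
    from vllm.config import CompilationLevel
    from vllm.utils import LazyLoader  # noqa: F401
else:
    # vLLM main renamed CompilationLevel -> CompilationMode (vllm#26355)
    # and moved LazyLoader into vllm.utils.import_utils (vllm#27022).
    from vllm.config import CompilationMode
    from vllm.utils.import_utils import LazyLoader  # noqa: F401


def uses_piecewise_compile(compilation_config, model_config) -> bool:
    """Simplified mirror of the _use_aclgraph() branch in this diff: the same
    check is expressed against whichever enum the installed vLLM provides
    (the CUDAGraphMode check from the real method is omitted here)."""
    if model_config.enforce_eager:
        return False
    if vllm_version_is("0.11.0"):
        return compilation_config.level == CompilationLevel.PIECEWISE
    return compilation_config.mode == CompilationMode.VLLM_COMPILE
```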
@@ -44,8 +44,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
@@ -59,18 +58,22 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
                                                    supports_mrope,
                                                    supports_transcription)
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LazyLoader, cdiv, get_dtype_size,
                        is_pin_memory_available)
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
                        get_dtype_size, is_pin_memory_available)
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
@@ -92,7 +95,6 @@ from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -120,7 +122,6 @@ from vllm_ascend.eplb.core.eplb_utils import EPLBParamUtils
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
@@ -134,7 +135,8 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               AscendSocVersion, ProfileExecuteDuration,
                               enable_sp, get_ascend_soc_version, is_310p,
                               is_enable_nz, lmhead_tp_enable,
                               prefill_context_parallel_enable)
                               prefill_context_parallel_enable,
                               vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

if prefill_context_parallel_enable():
@@ -143,6 +145,19 @@ if prefill_context_parallel_enable():
        get_prefill_context_model_parallel_rank,
        get_prefill_context_model_parallel_world_size)

# yapf: enable

if vllm_version_is("0.11.0"):
    from vllm.attention.layer import Attention
    from vllm.config import CompilationLevel
    from vllm.utils import LazyLoader

    from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
else:
    from vllm.attention.layer import MLAAttention
    from vllm.config import CompilationMode
    from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import SchedulerOutput
@@ -556,6 +571,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                                  dtype=torch.int64)
        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                  dtype=torch.int32)
        # Only relevant for multimodal models
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config)
        if self.supports_mm_inputs:
            self.is_mm_embed = self._make_buffer(self.max_num_tokens,
                                                 dtype=torch.bool)
        # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
        self.is_multimodal_pruning_enabled = False

    def _may_pad_kv_consumer_num_seq(self):
        # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
@@ -615,7 +639,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

    def _use_aclgraph(self) -> bool:
        return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
        if vllm_version_is("0.11.0"):
            return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
        else:
            return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        # Remove finished requests from the cached states.
@@ -807,16 +834,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                if mm_input.get("use_audio_in_video") is True:
                    use_audio_in_video = True

            req_state.mrope_positions, req_state.mrope_position_delta = \
                MRotaryEmbedding.get_input_positions_tensor(
                    req_state.prompt_token_ids,
                    hf_config=self.model_config.hf_config,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    audio_feature_lengths=audio_feature_lengths,
                    use_audio_in_video=use_audio_in_video,
                )
            if vllm_version_is("0.11.0"):
                req_state.mrope_positions, req_state.mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        req_state.prompt_token_ids,
                        hf_config=self.model_config.hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                        audio_feature_lengths=audio_feature_lengths,
                        use_audio_in_video=use_audio_in_video,
                    )
            else:
                if supports_mrope(self.model):
                    req_state.mrope_positions, req_state.mrope_position_delta = \
                        self.model.get_mrope_input_positions(
                            req_state.prompt_token_ids,
                            hf_config=self.model_config.hf_config,
                            image_grid_thw=image_grid_thw,
                            video_grid_thw=video_grid_thw,
                            second_per_grid_ts=second_per_grid_ts,
                            audio_feature_lengths=audio_feature_lengths,
                            use_audio_in_video=use_audio_in_video,
                        )

    def _sync_metadata_across_dp(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
@@ -1007,11 +1047,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            scheduler_output)
        encoder_outputs = []

        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
        if vllm_version_is("0.11.0"):
            mm_inputs = group_mm_kwargs_by_modality(
                mm_kwargs,
                device=self.device,
                pin_memory=True,
        ):
                pin_memory=self.pin_memory,
            )
        else:
            model = cast(SupportsMultiModal, self.model)
            mm_inputs = group_mm_kwargs_by_modality(
                mm_kwargs,
                device=self.device,
                pin_memory=self.pin_memory,
                merge_by_field_config=model.merge_by_field_config,
            )
        for modality, num_items, mm_kwargs_group in mm_inputs:
            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
@@ -1069,7 +1119,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):

        return mm_kwargs, mm_hashes_pos

    def _gather_mm_embeddings(
    def _gather_mm_embeddings_0110(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
@@ -1119,6 +1169,77 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            mm_embeds.append(mm_embeds_item)
        return mm_embeds

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
        shift_computed_tokens: int = 0,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

        mm_embeds = list[torch.Tensor]()
        is_mm_embed = self.is_mm_embed.cpu
        is_mm_embed[:total_num_scheduled_tokens] = False

        req_start_idx = 0

        for req_id in self.input_batch.req_ids:
            mm_embeds_req: list[torch.Tensor] = []

            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = \
                req_state.num_computed_tokens + shift_computed_tokens

            for mm_feature in req_state.mm_features:  # type: ignore
                pos_info = mm_feature.mm_position
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                # num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens,
                )
                assert start_idx < end_idx

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash, None)
                assert encoder_output is not None,\
                    f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                req_start_pos = req_start_idx + start_pos - num_computed_tokens
                is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
                    = True if is_embed is None else is_embed

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds_req.append(mm_embeds_item)

            mm_embeds.extend(mm_embeds_req)
            req_start_idx += num_scheduled_tokens

        is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

        return mm_embeds, is_mm_embed

    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
@@ -1429,17 +1550,28 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:total_num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            if vllm_version_is("0.11.0"):
                mm_embeds = self._gather_mm_embeddings_0110(scheduler_output)
                if mm_embeds:
                    inputs_embeds = self.model.get_input_embeddings(
                        input_ids, mm_embeds)
                else:
                    inputs_embeds = self.model.get_input_embeddings(input_ids)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
                mm_embeds, is_mm_embed = self._gather_mm_embeddings(
                    scheduler_output)

                inputs_embeds = self.model.get_input_embeddings(
                    input_ids,
                    multimodal_embeddings=mm_embeds,
                    is_multimodal=is_mm_embed,
                )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:total_num_scheduled_tokens].copy_(
                inputs_embeds)
@@ -1780,6 +1912,86 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        )
        return metadata

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ) -> torch.Tensor:
        grammar_bitmask = scheduler_output.grammar_bitmask

        # We receive the structured output bitmask from the scheduler,
        # compacted to contain bitmasks only for structured output requests.
        # The order of the requests in the bitmask is not guaranteed to be the
        # same as the order of the requests in the gpu runner's batch. We need
        # to sort the bitmask to match the order of the requests used here.

        # Get the batch indices of the structured output requests.
        # Keep track of the number of speculative tokens scheduled for every
        # request in the batch, as the logit indices are offset by this amount.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        seq = sorted(self.input_batch.req_id_to_index.items(),
                     key=lambda x: x[1])
        for req_id, batch_index in seq:
            logit_index = batch_index + cumulative_offset
            cumulative_offset += len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
            if req_id in scheduler_output.structured_output_request_ids:
                struct_out_req_batch_indices[req_id] = logit_index

        out_indices = []

        # Reorder the bitmask to match the order of the requests in the batch.
        sorted_bitmask = np.zeros_like(grammar_bitmask,
                                       shape=(logits.shape[0],
                                              grammar_bitmask.shape[1]))
        cumulative_index = 0
        if vllm_version_is("0.11.0"):
            seq = sorted(
                scheduler_output.structured_output_request_ids.items(),
                key=lambda x: x[1])
            for req_id, _ in seq:
                logit_index = struct_out_req_batch_indices[req_id]
                num_spec_tokens = len(
                    scheduler_output.scheduled_spec_decode_tokens.get(
                        req_id, []))
                for i in range(1 + num_spec_tokens):
                    sorted_bitmask[logit_index + i] = \
                        grammar_bitmask[cumulative_index + i]
                    out_indices.append(logit_index + i)
                cumulative_index += 1 + num_spec_tokens
        else:
            for req_id in scheduler_output.structured_output_request_ids:
                num_spec_tokens = len(
                    scheduler_output.scheduled_spec_decode_tokens.get(
                        req_id, []))
                if req_id in struct_out_req_batch_indices:
                    logit_index = struct_out_req_batch_indices[req_id]
                    for i in range(1 + num_spec_tokens):
                        sorted_bitmask[logit_index +
                                       i] = grammar_bitmask[cumulative_index +
                                                            i]
                        out_indices.append(logit_index + i)
                cumulative_index += 1 + num_spec_tokens
        grammar_bitmask = sorted_bitmask

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = torch.from_numpy(grammar_bitmask)

        # NOTE:
        # 1. XGrammar bitmask applying only supports CPU and GPU.
        # 2. The logits and bitmask should be on the same device.
        # 3. XGrammar logits on CPU only supports float32 dtype.
        logits_dtype = logits.dtype
        logits = logits.to("cpu").float()
        xgr.apply_token_bitmask_inplace(
            logits,
            grammar_bitmask,
            indices=out_indices,
        )
        return logits.to(self.device).to(logits_dtype)

    def propose_draft_token_ids(
        self,
        valid_sampled_token_ids: list[list[int]],
@@ -2027,17 +2239,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            logits = model_output_broadcast_data["logits"]

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            assert logits is not None
            # NOTE:
            # 1. XGrammar bitmask applying only supports CPU and GPU.
            # 2. The logits and bitmask should be on the same device.
            # 3. XGrammar logits on CPU only supports float32 dtype.
            logits_dtype = logits.dtype
            logits = logits.to("cpu").float()
            apply_grammar_bitmask(scheduler_output, self.input_batch,
                                  logits, torch.device("cpu"))
            logits = logits.to(self.device).to(logits_dtype)
        if vllm_version_is("0.11.0"):
            if scheduler_output.grammar_bitmask is not None:
                logits = self.apply_grammar_bitmask(
                    scheduler_output, logits)
        else:
            if scheduler_output.structured_output_request_ids:
                logits = self.apply_grammar_bitmask(
                    scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
@@ -3331,7 +3540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        else:
            self.reorder_batch_threshold = reorder_batch_threshold_i

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
@@ -3420,6 +3629,103 @@ class NPUModelRunner(LoRAModelRunnerMixin):

        return kv_cache_spec

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.
        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.
        """
        if vllm_version_is("0.11.0"):
            return self.get_kv_cache_spec_v0110()

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                  AttentionLayerBase)
        for layer_name, attn_module in attn_layers.items():
            if isinstance(attn_module, Attention):
                if (kv_tgt_layer :=
                        attn_module.kv_sharing_target_layer_name) is not None:
                    # The layer doesn't need its own KV cache and will use that of
                    # the target layer. We skip creating a KVCacheSpec for it, so
                    # that KV cache management logic will act as this layer does
                    # not exist, and doesn't allocate KV cache for the layer. This
                    # enables the memory saving of cross-layer kv sharing, allowing
                    # a given amount of memory to accommodate longer context lengths
                    # or enable more requests to be processed simultaneously.
                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                    continue

                # TODO: Support other attention modules, e.g., cross-attention
                # TODO(lucas): move the attention specs into the model layers like
                # the attention backends
                if attn_module.attn_type == AttentionType.DECODER:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
                elif attn_module.attn_type in (AttentionType.ENCODER,
                                               AttentionType.ENCODER_ONLY):
                    # encoder-only attention does not need KV cache.
                    continue
                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                    raise NotImplementedError
                else:
                    raise ValueError(
                        f"Unknown attention type: {attn_module.attn_type}")

            elif isinstance(attn_module, MLAAttention):
                if use_mla and not self.use_sparse:
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=1,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        cache_dtype_str=self.cache_config.cache_dtype)
                else:
                    # TODO(cmq): This is a hack way to fix deepseek kvcache when
                    # using DSA. Fix the spec in vLLM is a finnal way.
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=1,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if (self.vllm_config.speculative_config is not None
                    and self.vllm_config.model_config.hf_config.model_type
                    not in ["qwen3_next"]):
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config else 0),
                )

        return kv_cache_spec

    def initialize_aclgraph_capture(self) -> None:
        min_ag_support = AttentionCGSupport.ALWAYS
        min_ag_builder_name = None