Drop 0.11.0 support (#4377)
There is a lot hack code for v0.11.0, which makes the code hard to
upgrade to newer vLLM version. Since v0.11.0 will release soon. Let's
drop v0.11.0 support first. Then we'll upgrade to v0.11.2 soon.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -41,10 +41,11 @@ import torch.nn as nn
|
||||
from tqdm import tqdm # type: ignore
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config)
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
@@ -58,8 +59,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
supports_mrope,
|
||||
supports_transcription)
|
||||
@@ -73,29 +72,23 @@ from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.utils import cdiv
|
||||
else:
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
MambaSpec, MLAAttentionSpec,
|
||||
UniformTypeKVCacheSpecs)
|
||||
# yapf: enable
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
DraftTokenIds, LogprobsTensors, ModelRunnerOutput,
|
||||
PoolerOutput)
|
||||
@@ -119,6 +112,7 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
AscendPrefillContextParallelMetadata)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
set_graph_params,
|
||||
@@ -147,8 +141,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendSocVersion, ProfileExecuteDuration,
|
||||
enable_sp, get_ascend_soc_version, is_310p,
|
||||
is_enable_nz, is_moe_model, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable,
|
||||
vllm_version_is)
|
||||
prefill_context_parallel_enable)
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
@@ -157,27 +150,6 @@ if prefill_context_parallel_enable():
|
||||
get_prefill_context_model_parallel_rank,
|
||||
get_prefill_context_model_parallel_world_size)
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
get_dtype_size)
|
||||
else:
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
|
||||
|
||||
# yapf: enable
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.utils import LazyLoader, is_pin_memory_available
|
||||
|
||||
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
|
||||
else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.config import CompilationMode
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr # type: ignore[import-untyped]
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -637,11 +609,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
diagonal=1).to(self.device)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.drafter = self._get_drafter()
|
||||
if vllm_version_is("0.11.0"):
|
||||
self.rejection_sampler = AscendRejectionSampler()
|
||||
else:
|
||||
self.rejection_sampler = AscendRejectionSampler(
|
||||
self.sampler)
|
||||
self.rejection_sampler = AscendRejectionSampler(self.sampler)
|
||||
self.actual_seq_lengths_q = list(
|
||||
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||
self.decode_token_per_req))
|
||||
@@ -664,11 +632,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# tokens is less than or equal to mc2_tokens_capacity. According to _set_cudagraph_sizes,
|
||||
# the max number of tokens in graph is min(max_num_seqs * uniform_decode_query_len, 512).
|
||||
if self.compilation_config.cudagraph_capture_sizes:
|
||||
if vllm_version_is("0.11.0"):
|
||||
max_num_tokens = self.compilation_config.cudagraph_capture_sizes[
|
||||
0]
|
||||
else:
|
||||
max_num_tokens = self.compilation_config.max_cudagraph_capture_size
|
||||
max_num_tokens = self.compilation_config.max_cudagraph_capture_size
|
||||
else:
|
||||
# NOTE: To save memory, we cap the max number of tokens to 512.
|
||||
max_num_tokens = min(
|
||||
@@ -717,10 +681,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
if vllm_version_is("0.11.0"):
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
||||
else:
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove finished requests from the cached states.
|
||||
@@ -914,9 +875,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if mm_input.get("use_audio_in_video") is True:
|
||||
use_audio_in_video = True
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
if supports_mrope(self.model):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions_tensor(
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
@@ -925,18 +886,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
else:
|
||||
if supports_mrope(self.model):
|
||||
req_state.mrope_positions, req_state.mrope_position_delta = \
|
||||
self.model.get_mrope_input_positions(
|
||||
req_state.prompt_token_ids,
|
||||
hf_config=self.model_config.hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
|
||||
def _sync_metadata_across_dp(
|
||||
self, num_tokens: int,
|
||||
@@ -1108,21 +1057,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
|
||||
scheduler_output)
|
||||
encoder_outputs = []
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
mm_inputs = group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
else:
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
mm_inputs = group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
)
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
mm_inputs = group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
)
|
||||
for modality, num_items, mm_kwargs_group in mm_inputs:
|
||||
# Run the encoder.
|
||||
# `curr_group_outputs` is either of the following:
|
||||
@@ -1181,56 +1122,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return mm_kwargs, mm_hashes_pos
|
||||
|
||||
def _gather_mm_embeddings_0110(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> list[torch.Tensor]:
|
||||
|
||||
def _iter_mm_features(req_state: CachedRequestState):
|
||||
assert req_state.mm_features is not None
|
||||
for mm_feature in req_state.mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
yield mm_feature.identifier, pos_info, getattr(
|
||||
pos_info, "is_embed", None)
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
|
||||
for req_id in self.input_batch.req_ids:
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
|
||||
req_id]
|
||||
req_state = self.requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
|
||||
for mm_hash, pos_info, is_embed in _iter_mm_features(req_state):
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
if start_pos >= num_computed_tokens + num_scheduled_tokens:
|
||||
break
|
||||
if start_pos + num_encoder_tokens <= num_computed_tokens:
|
||||
continue
|
||||
|
||||
start_idx = max(num_computed_tokens - start_pos, 0)
|
||||
end_idx = min(
|
||||
num_computed_tokens - start_pos + num_scheduled_tokens,
|
||||
num_encoder_tokens,
|
||||
)
|
||||
assert start_idx < end_idx
|
||||
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
assert encoder_output is not None, \
|
||||
f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if is_embed is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
def _gather_mm_embeddings(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1730,22 +1621,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
input_ids = self.input_ids[:total_num_scheduled_tokens]
|
||||
if vllm_version_is("0.11.0"):
|
||||
mm_embeds = self._gather_mm_embeddings_0110(scheduler_output)
|
||||
if mm_embeds:
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids, mm_embeds)
|
||||
else:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
else:
|
||||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
|
||||
scheduler_output)
|
||||
mm_embeds, is_mm_embed = self._gather_mm_embeddings(
|
||||
scheduler_output)
|
||||
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
inputs_embeds = self.model.get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=mm_embeds,
|
||||
is_multimodal=is_mm_embed,
|
||||
)
|
||||
|
||||
# TODO(woosuk): Avoid the copy. Optimize.
|
||||
self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
|
||||
@@ -2151,9 +2034,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO: Optimize the CPU -> NPU copy.
|
||||
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
if not vllm_version_is("0.11.0"):
|
||||
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
|
||||
self.device, non_blocking=True)
|
||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
target_logits_indices = torch.from_numpy(target_logits_indices).to(
|
||||
@@ -2167,25 +2049,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
draft_token_ids = draft_token_ids[target_logits_indices + 1]
|
||||
if self.pcp_size > 1:
|
||||
logits_indices = logits_indices_pcp
|
||||
if vllm_version_is("0.11.0"):
|
||||
metadata = SpecDecodeMetadata(
|
||||
draft_token_ids=draft_token_ids,
|
||||
num_draft_tokens=num_draft_tokens.tolist(),
|
||||
cu_num_draft_tokens=cu_num_draft_tokens,
|
||||
target_logits_indices=target_logits_indices,
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
else:
|
||||
metadata = SpecDecodeMetadata(
|
||||
draft_token_ids=draft_token_ids,
|
||||
num_draft_tokens=num_draft_tokens.tolist(),
|
||||
cu_num_draft_tokens=cu_num_draft_tokens,
|
||||
cu_num_sampled_tokens=cu_num_sampled_tokens,
|
||||
target_logits_indices=target_logits_indices,
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
metadata = SpecDecodeMetadata(
|
||||
draft_token_ids=draft_token_ids,
|
||||
num_draft_tokens=num_draft_tokens.tolist(),
|
||||
cu_num_draft_tokens=cu_num_draft_tokens,
|
||||
cu_num_sampled_tokens=cu_num_sampled_tokens,
|
||||
target_logits_indices=target_logits_indices,
|
||||
bonus_logits_indices=bonus_logits_indices,
|
||||
logits_indices=logits_indices,
|
||||
)
|
||||
return metadata
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
@@ -2222,33 +2094,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
shape=(logits.shape[0],
|
||||
grammar_bitmask.shape[1]))
|
||||
cumulative_index = 0
|
||||
if vllm_version_is("0.11.0"):
|
||||
seq = sorted(
|
||||
scheduler_output.structured_output_request_ids.items(),
|
||||
key=lambda x: x[1])
|
||||
for req_id, _ in seq:
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
||||
if req_id in struct_out_req_batch_indices:
|
||||
logit_index = struct_out_req_batch_indices[req_id]
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, []))
|
||||
for i in range(1 + num_spec_tokens):
|
||||
sorted_bitmask[logit_index + i] = \
|
||||
grammar_bitmask[cumulative_index + i]
|
||||
sorted_bitmask[logit_index +
|
||||
i] = grammar_bitmask[cumulative_index + i]
|
||||
out_indices.append(logit_index + i)
|
||||
cumulative_index += 1 + num_spec_tokens
|
||||
else:
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, []))
|
||||
if req_id in struct_out_req_batch_indices:
|
||||
logit_index = struct_out_req_batch_indices[req_id]
|
||||
for i in range(1 + num_spec_tokens):
|
||||
sorted_bitmask[logit_index +
|
||||
i] = grammar_bitmask[cumulative_index +
|
||||
i]
|
||||
out_indices.append(logit_index + i)
|
||||
cumulative_index += 1 + num_spec_tokens
|
||||
cumulative_index += 1 + num_spec_tokens
|
||||
grammar_bitmask = sorted_bitmask
|
||||
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
@@ -2518,14 +2373,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if vllm_version_is("0.11.0"):
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
logits = self.apply_grammar_bitmask(
|
||||
scheduler_output, logits)
|
||||
else:
|
||||
if scheduler_output.structured_output_request_ids:
|
||||
logits = self.apply_grammar_bitmask(
|
||||
scheduler_output, logits)
|
||||
if scheduler_output.structured_output_request_ids:
|
||||
logits = self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
with ProfileExecuteDuration().capture_async("Sample"):
|
||||
# Sample the next token and get logprobs if needed.
|
||||
@@ -3837,95 +3686,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold_i
|
||||
|
||||
def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
use_sparse = self.use_sparse
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
continue
|
||||
if isinstance(attn_module, AscendMultiHeadLatentAttention):
|
||||
continue
|
||||
|
||||
# TODO: Support other attention modules, e.g., cross-attention
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if use_mla and not use_sparse:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
cache_dtype_str=self.cache_config.cache_dtype)
|
||||
else:
|
||||
# TODO(cmq): This is a hack way to fix deepseek kvcache when
|
||||
# using DSA. Fix the spec in vLLM is a finnal way.
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = self.vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (
|
||||
self.vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
@@ -3934,9 +3694,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
if vllm_version_is("0.11.0"):
|
||||
return self.get_kv_cache_spec_v0110()
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
|
||||
Reference in New Issue
Block a user