[ModelRunner][Qwen3-Next] Fix attn_group initialization timing (#3477)
### What this PR does / why we need it? Fix attn_group initialization timing so that fix qwen3-next model ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -2694,6 +2694,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"""
|
"""
|
||||||
kv_cache_config = deepcopy(kv_cache_config)
|
kv_cache_config = deepcopy(kv_cache_config)
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
|
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||||
|
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
|
||||||
|
self.initialize_attn_backend(kv_cache_config)
|
||||||
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
||||||
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
||||||
self.need_accepted_tokens = any([
|
self.need_accepted_tokens = any([
|
||||||
@@ -2702,8 +2705,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
])
|
])
|
||||||
|
|
||||||
self.may_reinitialize_input_batch(kv_cache_config)
|
self.may_reinitialize_input_batch(kv_cache_config)
|
||||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
|
||||||
self.initialize_attn_backend(kv_cache_config)
|
|
||||||
|
|
||||||
if self.use_sparse:
|
if self.use_sparse:
|
||||||
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
||||||
@@ -3100,6 +3101,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
block_sizes = [
|
block_sizes = [
|
||||||
kv_cache_group.kv_cache_spec.block_size
|
kv_cache_group.kv_cache_spec.block_size
|
||||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||||
|
if not isinstance(kv_cache_group.kv_cache_spec,
|
||||||
|
EncoderOnlyAttentionSpec)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Generate kernel_block_sizes that matches each block_size
|
# Generate kernel_block_sizes that matches each block_size
|
||||||
@@ -3109,7 +3112,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
kernel_block_sizes = []
|
kernel_block_sizes = []
|
||||||
for kv_cache_group_id, kv_cache_group in enumerate(
|
for kv_cache_group_id, kv_cache_group in enumerate(
|
||||||
kv_cache_config.kv_cache_groups):
|
kv_cache_config.kv_cache_groups):
|
||||||
if isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
|
|
||||||
|
if isinstance(kv_cache_group.kv_cache_spec,
|
||||||
|
EncoderOnlyAttentionSpec):
|
||||||
|
continue
|
||||||
|
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
|
||||||
# This is an attention backend that supports virtual
|
# This is an attention backend that supports virtual
|
||||||
# block splitting. Get the supported block sizes from
|
# block splitting. Get the supported block sizes from
|
||||||
# the backend.
|
# the backend.
|
||||||
@@ -3137,7 +3144,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# of mamba block. In this case, BlockTable.block_size will never equal
|
# of mamba block. In this case, BlockTable.block_size will never equal
|
||||||
# to kernel_block_sizes[0]
|
# to kernel_block_sizes[0]
|
||||||
kernel_block_sizes.append([0])
|
kernel_block_sizes.append([0])
|
||||||
if kernel_block_sizes != [[self.cache_config.block_size]]:
|
|
||||||
|
if block_sizes != [
|
||||||
|
self.cache_config.block_size
|
||||||
|
] or kernel_block_sizes != [self.cache_config.block_size]:
|
||||||
assert self.cache_config.cpu_offload_gb == 0, (
|
assert self.cache_config.cpu_offload_gb == 0, (
|
||||||
"Cannot re-initialize the input batch when CPU weight "
|
"Cannot re-initialize the input batch when CPU weight "
|
||||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||||
|
|||||||
Reference in New Issue
Block a user