### What this PR does / why we need it?
Upstream PR https://github.com/vllm-project/vllm/pull/35122 in the vllm
community refactored how block_size is updated. As a result, when the
user does not specify `--block-size`, DeepSeek-V3.2 (dsv3.2) ends up
with an incorrect block_size.
**Root cause, traced through the block_size update flow:**
1. In NPUPlatform, `check_and_update_config` calls `refresh_block_size`,
which sets block_size to 128.
2. During model runner initialization, `self.block_size` is captured
while block_size is still 128. This value is later used for operations
such as KV cache initialization.
3. `update_block_size_for_backend` then overwrites block_size with the
size reported by the attention backend. dsv3.2 breaks because it has an
additional attention backend, `DeepseekV32IndexerBackend`, which is not
overridden and reports a block_size of 64. Only
`vllm_config.cache_config.block_size` is updated; the copies captured
elsewhere (such as the model runner's `self.block_size`) are not, so
block_size becomes inconsistent across the whole network. A minimal
sketch of this stale-copy pattern follows the list.
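For intuition, the sketch below reproduces the stale-copy pattern in plain Python. All names here (`CacheConfig`, `ModelRunner`, and the simplified `update_block_size_for_backend` body) are illustrative stand-ins, not the real vLLM implementations:

```python
# Minimal sketch of the stale-copy pattern (illustrative names only,
# not the actual vLLM / vllm-ascend classes).

class CacheConfig:
    def __init__(self, block_size: int):
        self.block_size = block_size

class ModelRunner:
    def __init__(self, cache_config: CacheConfig):
        # Step 2: the runner snapshots block_size while it is still 128,
        # and sizes its KV caches from this snapshot.
        self.block_size = cache_config.block_size

def update_block_size_for_backend(cache_config: CacheConfig) -> None:
    # Step 3: the attention backend reports 64, and only the config
    # field is rewritten; earlier snapshots are left untouched.
    cache_config.block_size = 64

config = CacheConfig(block_size=128)    # step 1: platform default
runner = ModelRunner(config)            # step 2: snapshot == 128
update_block_size_for_backend(config)   # step 3: config -> 64

# The two views now disagree; this is the inconsistency hit by dsv3.2.
assert runner.block_size == 128 and config.block_size == 64
```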
**Fix:**
Skip `update_block_size_for_backend` and modify block_size only in the
`check_and_update_config` method.
In the future, the block_size update logic can be migrated into
`update_block_size_for_backend`, provided it then updates every
block_size copy across the whole network consistently.
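If that migration happens, the invariant to enforce is that every cached copy of block_size agrees with `vllm_config.cache_config.block_size`. A hypothetical consistency check (not part of this PR; the parameter names and helper are assumed for illustration) might look like:

```python
def assert_block_size_consistent(vllm_config, model_runner) -> None:
    """Hypothetical post-update check, not part of this PR.

    After any block_size update, every consumer must observe the same
    value as vllm_config.cache_config.block_size.
    """
    expected = vllm_config.cache_config.block_size
    actual = model_runner.block_size
    assert actual == expected, (
        f"model runner block_size ({actual}) diverged from "
        f"cache_config.block_size ({expected})"
    )
```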
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: ed359c497a
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
```diff
@@ -224,17 +224,9 @@ class NPUPlatform(Platform):
     @classmethod
     def update_block_size_for_backend(cls, vllm_config: VllmConfig) -> None:
-        cache_config = vllm_config.cache_config
-        if cache_config.user_specified_block_size:
-            # User specified --block-size; keep it.
-            return
-        model_config = vllm_config.model_config
-        if model_config is not None and model_config.is_hybrid:
-            # Hybrid attention+mamba models rely on the model-specific sizing
-            # logic rather than the generic platform default.
-            return
-
-        super().update_block_size_for_backend(vllm_config)
+        # TODO: NPU still sets block_size in check_and_update_config.
+        # Move that logic here so block_size is chosen by the backend.
+        pass
 
     @classmethod
     def set_device(cls, device: torch.device):
```
```diff
@@ -1101,12 +1101,12 @@ def refresh_block_size(vllm_config):
     if not scheduler_config or not model_config:
         return
 
-    # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
-    if (
-        "qwen3_next" not in model_config.hf_text_config.model_type
-        and "qwen3_5" not in model_config.hf_text_config.model_type
-        and cache_config.block_size != 128
-    ):
+    if model_config.is_hybrid:
+        # Hybrid attention+mamba models rely on the model-specific sizing
+        # logic rather than the generic platform default.
+        return
+
+    if cache_config.block_size != 128:
         if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
             logger.info("Block size is set to 128 if prefix cache or chunked prefill is enabled.")
             cache_config.block_size = 128
```