[Worker] Implement update max_model_len interface for NPUWorker (#6193)
### What this PR does / why we need it?
This patch purpose to add the `update_max_model_len` interface.
- vLLM version: v0.14.0
- vLLM main:
d68209402d
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -92,6 +92,7 @@ jobs:
|
||||
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
|
||||
# the test separately.
|
||||
# basic
|
||||
pytest -sv --durations=0 tests/e2e/singlecard/test_auto_fit_max_mode_len.py
|
||||
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_accuracy.py
|
||||
pytest -sv --durations=0 tests/e2e/singlecard/test_aclgraph_mem.py
|
||||
pytest -sv --durations=0 tests/e2e/singlecard/test_async_scheduling.py
|
||||
|
||||
100
tests/e2e/singlecard/test_auto_fit_max_mode_len.py
Normal file
100
tests/e2e/singlecard/test_auto_fit_max_mode_len.py
Normal file
@@ -0,0 +1,100 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
def new_kv_cache_spec(
|
||||
block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
page_size_padded=None,
|
||||
sliding_window=None,
|
||||
attention_chunk_size=None,
|
||||
):
|
||||
return FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
page_size_padded=page_size_padded,
|
||||
sliding_window=sliding_window,
|
||||
attention_chunk_size=attention_chunk_size,
|
||||
)
|
||||
|
||||
def test_auto_fit_max_model_len():
|
||||
"""Test that max_model_len=-1 auto-fits to available NPU memory."""
|
||||
# Create config with original_max_model_len=-1 to trigger auto-fit
|
||||
model_config = ModelConfig(max_model_len=1024)
|
||||
# Simulate the user passing -1 by setting original_max_model_len
|
||||
model_config.original_max_model_len = -1
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
# block_size * 2 * head_size * num_kv_heads * dtype_size
|
||||
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2 # 16KB per block per layer
|
||||
kv_cache_specs = {
|
||||
"layer_1": new_kv_cache_spec(),
|
||||
"layer_2": new_kv_cache_spec(),
|
||||
}
|
||||
|
||||
# With enough memory, max_model_len stays at the derived max
|
||||
large_available_memory = mem_per_block_per_layer * 2 * 1024 # plenty of memory
|
||||
_kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs], [large_available_memory]
|
||||
)
|
||||
assert vllm_config.model_config.max_model_len == 1024
|
||||
|
||||
# Reset for next test
|
||||
model_config = ModelConfig(max_model_len=1024)
|
||||
model_config.original_max_model_len = -1
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
# With limited memory, max_model_len should be reduced
|
||||
# Need memory for at least max_model_len tokens
|
||||
# 32 blocks worth of memory for 2 layers = can fit 32*16=512 tokens
|
||||
limited_memory = mem_per_block_per_layer * 2 * 32
|
||||
_kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs], [limited_memory]
|
||||
)
|
||||
# Should be reduced to fit in memory
|
||||
assert vllm_config.model_config.max_model_len < 1024
|
||||
assert vllm_config.model_config.max_model_len > 0
|
||||
|
||||
|
||||
def test_auto_fit_max_model_len_not_triggered():
|
||||
"""Test that auto-fit is not triggered when original_max_model_len is not -1."""
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
# original_max_model_len should be None by default, not -1
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
|
||||
kv_cache_specs = {
|
||||
"layer_1": new_kv_cache_spec(),
|
||||
"layer_2": new_kv_cache_spec(),
|
||||
}
|
||||
|
||||
# This should work normally without auto-fit
|
||||
_kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs], [mem_per_block_per_layer * 2 * 32]
|
||||
)
|
||||
assert vllm_config.model_config.max_model_len == 16
|
||||
@@ -426,6 +426,19 @@ class NPUWorker(WorkerBase):
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
return self.model_runner.get_kv_cache_spec()
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
"""Update max_model_len after auto-fit to NPU memory.
|
||||
|
||||
This is called when max_model_len=-1 is used and the engine
|
||||
automatically determines the maximum context length that fits
|
||||
in GPU memory. Workers need to update their cached max_model_len
|
||||
to match the engine's decision.
|
||||
"""
|
||||
self.model_config.max_model_len = max_model_len
|
||||
if self.model_runner is not None:
|
||||
self.model_runner.update_max_model_len(max_model_len)
|
||||
logger.debug("Updated max_model_len to %d", max_model_len)
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate NPU KV cache with the specified kv_cache_config."""
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
|
||||
Reference in New Issue
Block a user