[Worker][V1] Support sleep mode for v1 (#1084)

### What this PR does / why we need it?
 Support sleep mode for v1

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-06-06 21:54:02 +08:00
committed by GitHub
parent 0395ab30be
commit a2552e10e4
5 changed files with 65 additions and 60 deletions

View File

@@ -114,11 +114,13 @@ jobs:
# pytest -sv tests/singlecard/test_guided_decoding.py.py
# test_ascend_config.py should be ran separately because it will regenerate the global config many times.
pytest -sv tests/singlecard/test_ascend_config.py
pytest -sv tests/singlecard/test_camem.py
pytest -sv tests/singlecard/ \
--ignore=tests/singlecard/test_offline_inference.py \
--ignore=tests/singlecard/test_scheduler.py \
--ignore=tests/singlecard/test_guided_decoding.py \
--ignore=tests/singlecard/test_ascend_config.py
--ignore=tests/singlecard/test_ascend_config.py \
--ignore=tests/singlecard/test_camem.py
else
pytest -sv tests/multicard/test_ilama_lora_tp2.py
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py

View File

@@ -16,9 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.utils import GiB_bytes
@@ -26,9 +24,6 @@ from vllm.utils import GiB_bytes
from tests.utils import fork_new_process_for_each_test
from vllm_ascend.device_allocator.camem import CaMemAllocator
if os.getenv("VLLM_USE_V1") == "1":
pytest.skip("Skip in vllm v1", allow_module_level=True)
@fork_new_process_for_each_test
def test_basic_camem():

View File

@@ -15,6 +15,7 @@
# This file is a part of the vllm-ascend project.
#
import gc
import logging
import os
from typing import TYPE_CHECKING, Optional, Tuple
@@ -118,6 +119,12 @@ class NPUPlatform(Platform):
def mem_get_info(cls) -> Tuple[int, int]:
return torch.npu.mem_get_info()
@classmethod
def clear_npu_memory(cls):
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# initialize ascend config from vllm additional_config

View File

@@ -1235,11 +1235,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# assert self.lora_manager is not None, "LoRA is not enabled"
# TODO: call maybe_profile_with_lora()
dummy_kv_caches = [
torch.tensor((), dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
@@ -1250,7 +1245,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
logits = None
NPUPlatform.synchronize()
del hidden_states, logits, dummy_kv_caches
del hidden_states, logits
self.encoder_cache.clear()
gc.collect()

View File

@@ -17,8 +17,7 @@
# Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py
#
import gc
from typing import Dict, List, Optional
from typing import Optional
import torch
import torch.nn as nn
@@ -33,16 +32,15 @@ from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.worker_base import WorkerBase
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
@@ -95,10 +93,22 @@ class NPUWorker(WorkerBase):
self.profiler = self._init_profiler()
def sleep(self, level: int = 1) -> None:
logger.error("Sleep mode is only supported on v0")
NPUPlatform.set_device(self.device)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
allocator = CaMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
logger.error("Sleep mode is only supported on v0")
allocator = CaMemAllocator.get_instance()
allocator.wake_up(tags=tags)
def init_device(self):
if self.device_config.device.type == "npu":
@@ -119,58 +129,42 @@ class NPUWorker(WorkerBase):
self.model_runner = NPUModelRunner(self.vllm_config, self.device)
def determine_available_memory(self) -> int:
kv_caches: Dict[str, torch.Tensor] = {}
kv_cache_spec = self.model_runner.get_kv_cache_spec()
for layer_name, layer_spec in kv_cache_spec.items():
if isinstance(layer_spec, FullAttentionSpec):
# Use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
npu_k_cache = torch.tensor([],
dtype=layer_spec.dtype,
device=self.device)
npu_v_cache = torch.tensor([],
dtype=layer_spec.dtype,
device=self.device)
kv_caches[layer_name] = (npu_k_cache, npu_v_cache)
else:
raise NotImplementedError
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
NPUPlatform.empty_cache()
NPUPlatform.clear_npu_memory()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
_, total_npu_memory = NPUPlatform.mem_get_info()
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
free_npu_memory, _ = NPUPlatform.mem_get_info()
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_npu_memory - free_npu_memory
assert peak_memory > 0, (
assert self.init_npu_memory > free_npu_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_npu_memory}, current free memory"
f" {free_npu_memory}. This happens when the NPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
gc.collect()
# Get the peak memory allocation recorded by torch
peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
# TODO: don`t need impl this func after empty_cache in
# Worker.determine_num_available_blocks() unified`
NPUPlatform.empty_cache()
usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
npu_kv_cache_bytes = max(usable_memory_size, 0)
logger.info(
f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}"
)
return int(npu_kv_cache_bytes)
torch_allocated_bytes = torch_npu.npu.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch_npu.npu.mem_get_info(
)[1] - torch_npu.npu.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_npu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
return int(available_kv_cache_memory)
def execute_model(
self,
@@ -180,7 +174,17 @@ class NPUWorker(WorkerBase):
return output if self.is_driver_worker else None
def load_model(self) -> None:
self.model_runner.load_model()
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CaMemAllocator.get_instance()
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be "
"used for one instance per process.")
context = allocator.use_memory_pool(tag="weights")
else:
from contextlib import nullcontext
context = nullcontext() # type: ignore
with context:
self.model_runner.load_model()
def compile_or_warm_up_model(self) -> None:
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
@@ -206,12 +210,14 @@ class NPUWorker(WorkerBase):
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate NPU KV cache with the specified kv_cache_config."""
self.model_runner.initialize_kv_cache(kv_cache_config)
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
self.model_runner.initialize_kv_cache(kv_cache_config)
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CaMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
from contextlib import nullcontext
context = nullcontext() # type: ignore
with context:
self.model_runner.initialize_kv_cache(kv_cache_config)
def profile(self, is_start: bool = True):
if self.profiler is None: