adapt to vllm-ascend v0.18.0rc1
@@ -37,3 +37,6 @@ if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXP
 
 if envs.VLLM_ASCEND_BALANCE_SCHEDULING:
     import vllm_ascend.patch.platform.patch_balance_schedule  # noqa
+
+import vllm_ascend.patch.platform.patch_executor  # noqa
+import vllm_ascend.patch.platform.patch_core  # noqa
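
Both new modules take effect purely by being imported here, following the same import-time patching pattern as the patch_balance_schedule import above. A minimal, self-contained sketch of that pattern (all names below are illustrative stand-ins, not the real vLLM classes):

    # Patch modules assign replacements onto the target class at import time.
    class Target:                  # stands in for EngineCore / Executor
        def method(self):
            return "original"

    def patched_method(self):      # stands in for run_busy_loop, etc.
        return "patched"

    # Module-level assignment: merely importing this module swaps behavior.
    Target.method = patched_method

    assert Target().method() == "patched"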
vllm_ascend/patch/platform/patch_core.py (new file, 151 lines)
@@ -0,0 +1,151 @@
from logging import DEBUG
import os
import queue
import time

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
                                         get_kv_cache_configs)
from vllm.v1.engine.core import EngineCoreProc, EngineCore
from vllm.tracing import instrument

import vllm_ascend.envs as envs_ascend

def run_busy_loop(self):
    """Core busy loop of the EngineCore."""
    while self._handle_shutdown():
        # 1) Poll the input queue until there is work to do.
        self._process_input_queue()
        if (
            envs_ascend.VLLM_ASCEND_ENABLE_VNPU
            and self.has_work()
            and self.model_executor.is_offloaded()
        ):
            prev_is_self = self.model_executor.reload_vram()
            if not prev_is_self:
                self.reset_prefix_cache()
        # 2) Step the engine core and return the outputs.
        self._process_engine_step()
        if (
            envs_ascend.VLLM_ASCEND_ENABLE_VNPU
            and not self.has_work()
            and not self.model_executor.is_offloaded()
        ):
            self.model_executor.offload_vram()

    raise SystemExit

def _process_input_queue(self):
    """Exits when an engine step needs to be performed."""

    waited = False
    while not self.has_work() and self.is_running():
        # Notify callbacks waiting for engine to become idle.
        self._notify_idle_state_callbacks()
        if self.input_queue.empty():
            # Drain aborts queue; all aborts are also processed via input_queue.
            with self.aborts_queue.mutex:
                self.aborts_queue.queue.clear()
            if logger.isEnabledFor(DEBUG):
                logger.debug("EngineCore waiting for work.")
                waited = True
            # vNPU offload if idle.
            if (
                envs_ascend.VLLM_ASCEND_ENABLE_VNPU
                and not self.model_executor.is_offloaded()
            ):
                self.model_executor.offload_vram()
        block = self.process_input_queue_block
        try:
            req = self.input_queue.get(block=block)
            self._handle_client_request(*req)
        except queue.Empty:
            break
        if not block:
            break

    if waited:
        logger.debug("EngineCore loop active.")

    # Handle any more client requests.
    while not self.input_queue.empty():
        req = self.input_queue.get_nowait()
        self._handle_client_request(*req)

@instrument(span_name="Prepare model")
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
    start = time.time()

    # Get all kv cache specs needed by the model.
    kv_cache_specs = self.model_executor.get_kv_cache_specs()

    has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
    if has_kv_cache:
        if envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
            # Get available memory in idle-offload mode.
            available_gpu_memory = (
                self.model_executor.determine_available_memory_vnpu_offload_mode())
            self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
        elif envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            # NOTE(yongji): should already be set
            # during _eep_scale_up_before_kv_init
            assert self.available_gpu_memory_for_kv_cache > 0
            available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
                kv_cache_specs)
        else:
            # Profile the peak memory usage of the model to determine how
            # much memory can be allocated for kv cache.
            available_gpu_memory = self.model_executor.determine_available_memory()
            self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
    else:
        # Attention-free models don't need memory for kv cache.
        available_gpu_memory = [0] * len(kv_cache_specs)

    assert len(kv_cache_specs) == len(available_gpu_memory)

    # Track max_model_len before KV cache config to detect auto-fit changes.
    max_model_len_before = vllm_config.model_config.max_model_len

    kv_cache_configs = get_kv_cache_configs(
        vllm_config, kv_cache_specs, available_gpu_memory)

    # If auto-fit reduced max_model_len, sync the new value to workers.
    # This is needed because workers were spawned before memory profiling
    # and have the original (larger) max_model_len cached.
    max_model_len_after = vllm_config.model_config.max_model_len
    if max_model_len_after != max_model_len_before:
        self.collective_rpc("update_max_model_len", args=(max_model_len_after,))

    scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
    vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
    kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
    if kv_cache_groups:
        vllm_config.cache_config.block_size = min(
            g.kv_cache_spec.block_size for g in kv_cache_groups)

    vllm_config.validate_block_size()

    # Initialize kv cache and warm up the execution.
    self.model_executor.initialize_from_config(kv_cache_configs)

    elapsed = time.time() - start
    logger.info_once(
        "init engine (profile, create kv cache, warmup model) took %.2f seconds",
        elapsed,
        scope="local",
    )
    return scheduler_kv_cache_config

EngineCoreProc.run_busy_loop = run_busy_loop
EngineCoreProc._process_input_queue = _process_input_queue
EngineCore._initialize_kv_caches = _initialize_kv_caches
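
These module-level assignments replace the upstream EngineCore methods wholesale. To make the vNPU idle-offload lifecycle they introduce easier to follow, here is a condensed, runnable sketch of the state transitions run_busy_loop drives when VLLM_ASCEND_ENABLE_VNPU is enabled; FakeExecutor and step are illustrative stand-ins, not part of the patch:

    # Condensed model of the offload/reload cycle in run_busy_loop.
    class FakeExecutor:            # stand-in for the real model executor
        def __init__(self):
            self.offloaded = False

        def is_offloaded(self):
            return self.offloaded

        def offload_vram(self):
            self.offloaded = True

        def reload_vram(self):
            self.offloaded = False
            # Pretend another engine instance used the NPU meanwhile,
            # so the prefix cache is stale.
            return False

    def step(executor, has_work):
        # Mirrors the two checks in run_busy_loop: reload before serving
        # work, offload again once the engine goes idle.
        if has_work and executor.is_offloaded():
            prev_is_self = executor.reload_vram()
            if not prev_is_self:
                print("reset prefix cache")   # self.reset_prefix_cache()
        if not has_work and not executor.is_offloaded():
            executor.offload_vram()

    ex = FakeExecutor()
    step(ex, has_work=False)   # idle -> VRAM offloaded
    step(ex, has_work=True)    # work arrives -> reload (+ cache reset)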
vllm_ascend/patch/platform/patch_executor.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import time

from vllm.v1.executor.abstract import logger, Executor


def is_offloaded(self) -> bool:
    if not hasattr(self, "_is_offloaded"):
        self._is_offloaded = False
    return self._is_offloaded


def offload_vram(self):
    if self.is_offloaded():
        logger.warning("Executor is already offloaded.")
        return
    time_before_offload = time.perf_counter()
    self.collective_rpc("offload_vram")
    time_after_offload = time.perf_counter()

    self._is_offloaded = True
    logger.info(f"Offloading VRAM costs {time_after_offload - time_before_offload:.6f} seconds.")


def reload_vram(self) -> bool:
    if not self.is_offloaded():
        logger.warning("Executor is not offloaded.")
        return True

    while True:
        time_before_reload = time.perf_counter()
        res = self.collective_rpc("try_reload_vram")
        time_after_reload = time.perf_counter()

        succ = all(x[0] for x in res)
        if succ:
            self._is_offloaded = False
            logger.info(f"Reloading VRAM costs {time_after_reload - time_before_reload:.6f} seconds.")
            prev_is_self = all(x[1] for x in res)
            return prev_is_self
        else:
            # Some workers failed to acquire the lock: release and retry.
            self.collective_rpc("vnpu_unlock_gpu")
            time.sleep(0.001)


def determine_available_memory_vnpu_offload_mode(self) -> list[int]:
    # collective_rpc returns one result per worker; callers index into it.
    return self.collective_rpc("determine_available_memory_vnpu_offload_mode")


Executor.is_offloaded = is_offloaded
Executor.offload_vram = offload_vram
Executor.reload_vram = reload_vram
Executor.determine_available_memory_vnpu_offload_mode = determine_available_memory_vnpu_offload_mode
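
reload_vram only clears _is_offloaded once every worker reports success; if any worker misses the device lock, all workers unlock and the RPC is retried, so two engine instances sharing an NPU never hold partial lock sets against each other. A small self-contained sketch of that all-or-nothing acquisition, using threading locks as stand-ins for the per-device vNPU lock (try_reload is illustrative, not the real RPC):

    import threading
    import time

    locks = [threading.Lock(), threading.Lock()]   # stand-ins for NPU locks

    def try_reload(i):
        got = locks[i].acquire(blocking=False)
        return (got, True)   # (got_lock, previous owner was this instance)

    def reload_all():
        while True:
            res = [try_reload(i) for i in range(len(locks))]
            if all(got for got, _ in res):
                return all(prev for _, prev in res)
            # Mirror the patch: on any failure, release whatever was
            # acquired, back off briefly, and retry the whole set.
            for i, (got, _) in enumerate(res):
                if got:
                    locks[i].release()
            time.sleep(0.001)

    print(reload_all())   # True: safe to skip the prefix-cache reset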