vllm-ascend vnpu v1
This commit is contained in:
@@ -28,3 +28,4 @@ if os.getenv("DYNAMIC_EPLB", "false") == "true" or os.getenv(
|
||||
if os.getenv("SHM_BARRIER", "true") == "true":
|
||||
import vllm_ascend.patch.platform.patch_core # noqa
|
||||
import vllm_ascend.patch.platform.patch_message_queue # noqa
|
||||
import vllm_ascend.patch.platform.patch_executor # noqa
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
import signal
|
||||
from typing import Optional
|
||||
from logging import DEBUG
|
||||
import time
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.transformers_utils.config import \
|
||||
maybe_register_config_serialize_by_value
|
||||
from vllm.utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.engine.core import DPEngineCoreProc, EngineCoreProc, EngineCore
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
@@ -66,3 +73,101 @@ def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
|
||||
|
||||
EngineCoreProc.run_engine_core = run_engine_core
|
||||
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
|
||||
# Loop until process is sent a SIGINT or SIGTERM
|
||||
while True:
|
||||
# 1) Poll the input queue until there is work to do.
|
||||
self._process_input_queue()
|
||||
# 2) Step the engine core and return the outputs.
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD and self.scheduler.has_requests() and self.model_executor.is_offloaded:
|
||||
prev_is_self = self.model_executor.reload_vram()
|
||||
if not prev_is_self:
|
||||
self.reset_prefix_cache()
|
||||
self._process_engine_step()
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD and not self.scheduler.has_requests() and not self.model_executor.is_offloaded:
|
||||
self.model_executor.offload_vram()
|
||||
|
||||
def _process_input_queue(self):
|
||||
"""Exits when an engine step needs to be performed."""
|
||||
|
||||
waited = False
|
||||
while not self.engines_running and not self.scheduler.has_requests() \
|
||||
and not self.batch_queue:
|
||||
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
|
||||
logger.debug("EngineCore waiting for work.")
|
||||
waited = True
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD and not self.model_executor.is_offloaded:
|
||||
self.model_executor.offload_vram()
|
||||
req = self.input_queue.get()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
if waited:
|
||||
logger.debug("EngineCore loop active.")
|
||||
|
||||
# Handle any more client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(*req)
|
||||
|
||||
|
||||
EngineCoreProc.run_busy_loop = run_busy_loop
|
||||
EngineCoreProc._process_input_queue = _process_input_queue
|
||||
|
||||
|
||||
def _initialize_kv_caches(
|
||||
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
|
||||
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
|
||||
if has_kv_cache:
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_IDLE_OFFLOAD:
|
||||
# get available memory in idle offload mode
|
||||
available_gpu_memory = (
|
||||
self.model_executor.determine_available_memory_idle_offload_mode())
|
||||
self.available_gpu_memory_for_kv_cache = \
|
||||
available_gpu_memory[0]
|
||||
elif os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
|
||||
dp_group = getattr(self, "dp_group", None)
|
||||
assert dp_group is not None
|
||||
self.available_gpu_memory_for_kv_cache = \
|
||||
ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
|
||||
available_gpu_memory = [
|
||||
self.available_gpu_memory_for_kv_cache
|
||||
] * len(kv_cache_specs)
|
||||
else:
|
||||
# Profiles the peak memory usage of the model to determine how
|
||||
# much memory can be allocated for kv cache.
|
||||
available_gpu_memory = (
|
||||
self.model_executor.determine_available_memory())
|
||||
self.available_gpu_memory_for_kv_cache = \
|
||||
available_gpu_memory[0]
|
||||
else:
|
||||
# Attention free models don't need memory for kv cache
|
||||
available_gpu_memory = [0] * len(kv_cache_specs)
|
||||
|
||||
assert len(kv_cache_specs) == len(available_gpu_memory)
|
||||
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
|
||||
available_gpu_memory)
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
|
||||
kv_cache_configs)
|
||||
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
|
||||
num_cpu_blocks = 0
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize_from_config(kv_cache_configs)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info(("init engine (profile, create kv cache, "
|
||||
"warmup model) took %.2f seconds"), elapsed)
|
||||
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
||||
|
||||
|
||||
EngineCore._initialize_kv_caches = _initialize_kv_caches
|
||||
|
||||
44
vllm_ascend/patch/platform/patch_executor.py
Normal file
44
vllm_ascend/patch/platform/patch_executor.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import time
|
||||
|
||||
from vllm.executor.executor_base import logger, ExecutorBase
|
||||
|
||||
|
||||
original_init = ExecutorBase.__init__
|
||||
def init(self, *args, **kwargs):
|
||||
original_init(self, *args, **kwargs)
|
||||
self.is_offloaded = False
|
||||
|
||||
|
||||
def offload_vram(self):
|
||||
if self.is_offloaded:
|
||||
logger.warning("Executor is already offloaded.")
|
||||
return
|
||||
time_before_offload = time.perf_counter()
|
||||
self.collective_rpc("offload_vram")
|
||||
time_after_offload = time.perf_counter()
|
||||
|
||||
self.is_offloaded = True
|
||||
logger.info(f"Offloading VRAM costs {time_after_offload - time_before_offload:.6f} seconds.")
|
||||
|
||||
|
||||
def reload_vram(self) -> bool:
|
||||
if not self.is_offloaded:
|
||||
logger.warning("Executor is not offloaded.")
|
||||
return True
|
||||
|
||||
time_before_reload = time.perf_counter()
|
||||
prev_is_self = self.collective_rpc("reload_vram")
|
||||
time_after_reload = time.perf_counter()
|
||||
self.is_offloaded = False
|
||||
logger.info(f"Reloading VRAM costs {time_after_reload - time_before_reload:.6f} seconds.")
|
||||
return prev_is_self
|
||||
|
||||
|
||||
def determine_available_memory_idle_offload_mode(self) -> int:
|
||||
return self.collective_rpc("determine_available_memory_idle_offload_mode")
|
||||
|
||||
|
||||
ExecutorBase.__init__ = init
|
||||
ExecutorBase.offload_vram = offload_vram
|
||||
ExecutorBase.reload_vram = reload_vram
|
||||
ExecutorBase.determine_available_memory_idle_offload_mode = determine_available_memory_idle_offload_mode
|
||||
Reference in New Issue
Block a user