fix multiproc executor determine kv cache memory & update Dockerfile
This commit is contained in:
@@ -1,20 +1,103 @@
|
||||
from logging import DEBUG
|
||||
import os
|
||||
import signal
|
||||
import queue
|
||||
import time
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.engine.core import EngineCoreProc, EngineCore
|
||||
from vllm.tracing import instrument
|
||||
from vllm.tracing import maybe_init_worker_tracer
|
||||
from vllm.transformers_utils.config import \
|
||||
maybe_register_config_serialize_by_value
|
||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import EngineCoreProc, DPEngineCoreProc, EngineShutdownState
|
||||
from vllm.v1.engine import EngineCoreRequestType
|
||||
from vllm.v1.engine.utils import SignalCallback
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
|
||||
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
|
||||
"""Launch EngineCore busy loop in background process."""
|
||||
|
||||
# Ensure we can serialize transformer config after spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
engine_core: EngineCoreProc | None = None
|
||||
signal_callback: SignalCallback | None = None
|
||||
try:
|
||||
vllm_config: VllmConfig = kwargs["vllm_config"]
|
||||
parallel_config: ParallelConfig = vllm_config.parallel_config
|
||||
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
|
||||
if data_parallel:
|
||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||
process_title = f"EngineCore_DP{dp_rank}"
|
||||
else:
|
||||
process_title = "EngineCore"
|
||||
set_process_title(process_title)
|
||||
maybe_init_worker_tracer("vllm.engine_core", "engine_core", process_title)
|
||||
decorate_logs()
|
||||
|
||||
if data_parallel and vllm_config.kv_transfer_config is not None:
|
||||
# modify the engine_id and append the local_dp_rank to it to ensure
|
||||
# that the kv_transfer_config is unique for each DP rank.
|
||||
vllm_config.kv_transfer_config.engine_id = (
|
||||
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
|
||||
)
|
||||
logger.debug(
|
||||
"Setting kv_transfer_config.engine_id to %s",
|
||||
vllm_config.kv_transfer_config.engine_id,
|
||||
)
|
||||
|
||||
parallel_config.data_parallel_index = dp_rank
|
||||
if data_parallel and vllm_config.model_config.is_moe:
|
||||
# Set data parallel rank for this engine process.
|
||||
parallel_config.data_parallel_rank = dp_rank
|
||||
engine_core = DPEngineCoreProc(*args, **kwargs)
|
||||
else:
|
||||
# Non-MoE DP ranks are completely independent, so treat like DP=1.
|
||||
# Note that parallel_config.data_parallel_index will still reflect
|
||||
# the original DP rank.
|
||||
parallel_config.data_parallel_size = 1
|
||||
parallel_config.data_parallel_size_local = 1
|
||||
parallel_config.data_parallel_rank = 0
|
||||
engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
|
||||
|
||||
assert engine_core is not None
|
||||
|
||||
def wakeup_engine():
|
||||
# Wakes up idle engine via input_queue when shutdown is requested
|
||||
# Not safe in a signal handler - we may interrupt the main thread
|
||||
# while it is holding the non-reentrant input_queue.mutex
|
||||
engine_core.input_queue.put_nowait((EngineCoreRequestType.WAKEUP, None))
|
||||
|
||||
signal_callback = SignalCallback(wakeup_engine)
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
engine_core.shutdown_state = EngineShutdownState.REQUESTED
|
||||
signal_callback.trigger()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
engine_core.run_busy_loop()
|
||||
|
||||
except SystemExit:
|
||||
logger.debug("EngineCore exiting.")
|
||||
raise
|
||||
except Exception as e:
|
||||
if engine_core is None:
|
||||
logger.exception("EngineCore failed to start.")
|
||||
else:
|
||||
logger.exception("EngineCore encountered a fatal error.")
|
||||
engine_core._send_engine_dead()
|
||||
raise e
|
||||
finally:
|
||||
signal.signal(signal.SIGTERM, signal.SIG_DFL)
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
if signal_callback is not None:
|
||||
signal_callback.stop()
|
||||
if engine_core is not None:
|
||||
engine_core.shutdown()
|
||||
|
||||
def run_busy_loop(self):
|
||||
"""Core busy loop of the EngineCore."""
|
||||
while self._handle_shutdown():
|
||||
@@ -77,75 +160,8 @@ def _process_input_queue(self):
|
||||
self._handle_client_request(*req)
|
||||
|
||||
|
||||
@instrument(span_name="Prepare model")
|
||||
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
|
||||
has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
|
||||
if has_kv_cache:
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
||||
# get available memory in idle offload mode
|
||||
available_gpu_memory = (
|
||||
self.model_executor.determine_available_memory_vnpu_offload_mode())
|
||||
self.available_gpu_memory_for_kv_cache = \
|
||||
available_gpu_memory[0]
|
||||
elif envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
|
||||
# NOTE(yongji): should already be set
|
||||
# during _eep_scale_up_before_kv_init
|
||||
assert self.available_gpu_memory_for_kv_cache > 0
|
||||
available_gpu_memory = [self.available_gpu_memory_for_kv_cache] * len(
|
||||
kv_cache_specs
|
||||
)
|
||||
else:
|
||||
# Profiles the peak memory usage of the model to determine how
|
||||
# much memory can be allocated for kv cache.
|
||||
available_gpu_memory = self.model_executor.determine_available_memory()
|
||||
self.available_gpu_memory_for_kv_cache = available_gpu_memory[0]
|
||||
else:
|
||||
# Attention free models don't need memory for kv cache
|
||||
available_gpu_memory = [0] * len(kv_cache_specs)
|
||||
|
||||
assert len(kv_cache_specs) == len(available_gpu_memory)
|
||||
|
||||
# Track max_model_len before KV cache config to detect auto-fit changes
|
||||
max_model_len_before = vllm_config.model_config.max_model_len
|
||||
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, kv_cache_specs, available_gpu_memory
|
||||
)
|
||||
|
||||
# If auto-fit reduced max_model_len, sync the new value to workers.
|
||||
# This is needed because workers were spawned before memory profiling
|
||||
# and have the original (larger) max_model_len cached.
|
||||
max_model_len_after = vllm_config.model_config.max_model_len
|
||||
if max_model_len_after != max_model_len_before:
|
||||
self.collective_rpc("update_max_model_len", args=(max_model_len_after,))
|
||||
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
|
||||
vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
|
||||
kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
|
||||
if kv_cache_groups:
|
||||
vllm_config.cache_config.block_size = min(
|
||||
g.kv_cache_spec.block_size for g in kv_cache_groups
|
||||
)
|
||||
|
||||
vllm_config.validate_block_size()
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize_from_config(kv_cache_configs)
|
||||
|
||||
elapsed = time.time() - start
|
||||
logger.info_once(
|
||||
"init engine (profile, create kv cache, warmup model) took %.2f seconds",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
return scheduler_kv_cache_config
|
||||
|
||||
# to make multi-proc enginecore get patched
|
||||
EngineCoreProc.run_engine_core = run_engine_core
|
||||
|
||||
EngineCoreProc.run_busy_loop = run_busy_loop
|
||||
EngineCoreProc._process_input_queue = _process_input_queue
|
||||
EngineCore._initialize_kv_caches = _initialize_kv_caches
|
||||
|
||||
@@ -42,11 +42,6 @@ def reload_vram(self) -> bool:
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def determine_available_memory_vnpu_offload_mode(self) -> int:
|
||||
return self.collective_rpc("determine_available_memory_vnpu_offload_mode")
|
||||
|
||||
|
||||
Executor.is_offloaded = is_offloaded
|
||||
Executor.offload_vram = offload_vram
|
||||
Executor.reload_vram = reload_vram
|
||||
Executor.determine_available_memory_vnpu_offload_mode = determine_available_memory_vnpu_offload_mode
|
||||
|
||||
@@ -336,6 +336,25 @@ class NPUWorker(WorkerBase):
|
||||
bytes.
|
||||
"""
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_VNPU:
|
||||
allocator = CaMemAllocator.get_instance()
|
||||
free, total = allocator.get_pool_mem_info()
|
||||
if self.cache_config.gpu_memory_utilization <= 0.9:
|
||||
logger.warning(
|
||||
"GPU memory utilization is set to %.2f. For VNPU mode, it is recommended to set gpu_memory_utilization to a larger value",
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
)
|
||||
available_kv_cache_memory = int(
|
||||
total * self.cache_config.gpu_memory_utilization - (total - free)
|
||||
)
|
||||
available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
|
||||
self.available_kv_cache_memory_bytes = available_kv_cache_memory
|
||||
logger.info_once(
|
||||
"Available KV cache memory: %.2f GiB",
|
||||
GiB(self.available_kv_cache_memory_bytes),
|
||||
scope="local",
|
||||
)
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
@@ -363,28 +382,6 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
@torch.inference_mode()
|
||||
def determine_available_memory_vnpu_offload_mode(self) -> int:
|
||||
GiB = lambda b: b / GiB_bytes
|
||||
allocator = CaMemAllocator.get_instance()
|
||||
free, total = allocator.get_pool_mem_info()
|
||||
if self.cache_config.gpu_memory_utilization <= 0.9:
|
||||
logger.warning(
|
||||
"GPU memory utilization is set to %.2f. For VNPU mode, it is recommended to set gpu_memory_utilization to a larger value",
|
||||
self.cache_config.gpu_memory_utilization,
|
||||
)
|
||||
available_kv_cache_memory = int(
|
||||
total * self.cache_config.gpu_memory_utilization - (total - free)
|
||||
)
|
||||
available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
|
||||
self.available_kv_cache_memory_bytes = available_kv_cache_memory
|
||||
logger.info_once(
|
||||
"Available KV cache memory: %.2f GiB",
|
||||
GiB(self.available_kv_cache_memory_bytes),
|
||||
scope="local",
|
||||
)
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
|
||||
Reference in New Issue
Block a user