diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index feccb8a7..a03c0b3b 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2330,6 +2330,7 @@ class NPUModelRunner(GPUModelRunner):
             if self.lora_config:
                 self.model = self.load_lora_model(self.model, self.vllm_config,
                                                   self.device)
+        self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
                     m.consumed_memory / float(2**30))
         # wrap the model with full graph wrapper if needed.
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index c058e59b..be6070a2 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -37,6 +37,7 @@ from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -249,8 +250,25 @@ class NPUWorker(WorkerBase):
     def _init_device(self):
         device = torch.device(f"npu:{self.local_rank}")
         torch.npu.set_device(device)
+
+        gc.collect()
         torch.npu.empty_cache()
+        # take current memory snapshot
+        self.init_snapshot = MemorySnapshot()
+        self.requested_memory = self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization
+        if self.init_snapshot.free_memory < self.requested_memory:
+            GiB = lambda b: round(b / GiB_bytes, 2)
+            raise ValueError(
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                f"is less than desired GPU memory utilization "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
+                f"utilization or reduce GPU memory used by other processes."
+            )
+
         if (
             self.parallel_config.data_parallel_size > 1
             and self.parallel_config.data_parallel_size_local > 0
@@ -265,7 +283,6 @@ class NPUWorker(WorkerBase):
                 f"({visible_device_count})."
             )
 
-        self.init_npu_memory = torch.npu.mem_get_info()[0]
         # Initialize the distributed environment.
         self._init_worker_distributed_environment()
         # Set random seed.
@@ -300,43 +317,51 @@ class NPUWorker(WorkerBase):
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        gc.collect()
-        torch.npu.empty_cache()
-        torch.npu.reset_peak_memory_stats()
+        """Profiles the peak memory usage of the model to determine how much
+        memory can be used for KV cache without OOMs.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the free memory that can be used for KV cache in
+        bytes.
+        """
+        GiB = lambda b: b / GiB_bytes
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        _, total_npu_memory = torch.npu.mem_get_info()
-        self.model_runner.profile_run()
+        with memory_profiling(
+            self.init_snapshot,
+            weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
+            self.model_runner.profile_run()
+        free_memory, total_memory = torch.npu.mem_get_info()
+        torch_memory = torch.npu.memory_reserved()
+        non_torch_memory_before_empty_cache = total_memory - free_memory - torch_memory
 
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
-        free_npu_memory, _ = torch.npu.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        assert self.init_npu_memory > free_npu_memory, (
+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+        non_torch_memory_cleared_by_empty_cache = (
+            non_torch_memory_before_empty_cache - self.non_torch_memory)
+
+        free_gpu_memory = profile_result.after_profile.free_memory
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_npu_memory}, current free memory"
-            f" {free_npu_memory}. This happens when the NPU memory was "
-            "not properly cleaned up before initializing the vLLM instance."
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container."
+        )
+        self.available_kv_cache_memory_bytes = (
+            self.requested_memory - profile_result.non_kv_cache_memory -
+            non_torch_memory_cleared_by_empty_cache
         )
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
-        # TODO: don`t need impl this func after empty_cache in
-        # Worker.determine_num_available_blocks() unified`
-        torch.npu.empty_cache()
-        torch_allocated_bytes = torch_npu.npu.memory_stats(
-        )["allocated_bytes.all.current"]
-        total_allocated_bytes = torch_npu.npu.mem_get_info(
-        )[1] - torch_npu.npu.mem_get_info()[0]
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-        available_kv_cache_memory = int(total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory)
-        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
-        logger.info(f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}")
-        return available_kv_cache_memory
+        logger.debug(profile_result)
+        logger.info_once(
+            "Available KV cache memory: %.2f GiB",
+            GiB(self.available_kv_cache_memory_bytes),
+            scope="local",
+        )
+        return int(self.available_kv_cache_memory_bytes)
 
     def execute_model(
         self,
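
Note on the _init_device() change: the worker now takes a full MemorySnapshot once, before any weights are loaded, and immediately validates the requested budget (total_memory * gpu_memory_utilization) against free memory, so an over-committed device fails fast with an actionable ValueError instead of OOMing later during profiling. The snapshot also replaces the old self.init_npu_memory scalar, since memory_profiling() takes the whole baseline snapshot, not just a free-memory reading. This appears to mirror the equivalent startup check in vLLM's GPU worker.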
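
The KV cache budget computed in determine_available_memory() reduces to a few lines of arithmetic. The sketch below mirrors that formula with hypothetical byte values and no NPU access; the names weights_memory, torch_peak_increase, non_torch_increase, and non_kv_cache_memory follow the MemoryProfilingResult that memory_profiling() yields, where non_kv_cache_memory is the sum of weights, peak torch activations, and the non-torch increase.

# Minimal sketch of the budget arithmetic above. All byte values are
# hypothetical placeholders; on a real worker they come from
# MemorySnapshot and the memory_profiling() context manager.
GiB_bytes = 1 << 30

total_memory = 64 * GiB_bytes            # device total (hypothetical)
gpu_memory_utilization = 0.9
requested_memory = total_memory * gpu_memory_utilization   # 57.6 GiB budget

# What memory_profiling() would report after profile_run() (hypothetical):
weights_memory = 14 * GiB_bytes          # model_runner.model_memory_usage
torch_peak_increase = 6 * GiB_bytes      # peak activation memory
non_torch_increase = 1 * GiB_bytes       # e.g. comm buffers outside torch

# non_kv_cache_memory as MemoryProfilingResult defines it.
non_kv_cache_memory = weights_memory + torch_peak_increase + non_torch_increase

# Non-torch memory that empty_cache() released between the raw
# mem_get_info() reading and the profiling result (hypothetical).
non_torch_memory_cleared_by_empty_cache = GiB_bytes // 2

available_kv_cache_memory_bytes = (
    requested_memory - non_kv_cache_memory
    - non_torch_memory_cleared_by_empty_cache
)
print(f"Available KV cache memory: "
      f"{available_kv_cache_memory_bytes / GiB_bytes:.2f} GiB")   # 36.10 GiB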