[Refactor][Bugfix] Use upstream mem_utils for profiling and correct non-torch memory recorded during profiling (#6625)

### What this PR does / why we need it?

1. Following https://github.com/vllm-project/vllm/pull/32322, use the
`memory_profiling` context manager from vLLM for profiling (a usage sketch
follows this list).
2. Fix the non-torch memory value recorded during profiling, which is lower
than its actual peak during inference.
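
For point 1, the upstream helper is used roughly as follows. This is a minimal
sketch based on the worker diff at the bottom of this page; `model_runner` and
`weights_memory` are hypothetical stand-ins for the worker's own attributes,
and the exact signature may differ across vLLM versions:

```python
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling


def profile_memory(init_snapshot: MemorySnapshot, model_runner, weights_memory: int):
    """Run the dummy forward pass under the upstream profiler and return its result."""
    # `init_snapshot` is taken at device init, before the model weights are loaded.
    with memory_profiling(init_snapshot, weights_memory=weights_memory) as result:
        model_runner.profile_run()  # dummy forward pass
    # Fields this PR relies on (see the worker diff below):
    #   result.torch_peak_increase  -> peak activation memory during the run
    #   result.non_torch_increase   -> non-torch memory growth during profiling
    #   result.non_kv_cache_memory  -> memory that cannot be handed to the KV cache
    return result
```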

---
**More details about point 2:**

After profiling, the recorded non-torch memory value is lower than the value
observed during real inference. This is mainly due to the different memory
management behaviour of `torch.cuda.empty_cache()` and
`torch.npu.empty_cache()`.

`torch.cuda.empty_cache()` only recycles unused memory in the PyTorch memory
pool (i.e., memory managed by the PyTorch caching allocator), **with no effect
on non-torch memory**. `torch.npu.empty_cache()`, however, uses a different
memory management mechanism: it may call `aclrtSynchronize` and **allow the
Ascend runtime to free non-torch memory as well**.

Thus, the non-torch memory value recorded after `torch.npu.empty_cache()` is
much lower than its peak during profiling.
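
The effect can be observed directly with the same primitives the worker uses to
derive non-torch memory (total minus free minus the PyTorch-reserved pool).
This is an illustrative snippet, assuming `torch_npu` is installed and a
profiling run has just finished:

```python
import torch
import torch_npu  # noqa: F401  # registers the `torch.npu` device backend


def non_torch_memory_bytes() -> int:
    """Device memory held outside the PyTorch caching allocator."""
    free, total = torch.npu.mem_get_info()
    return total - free - torch.npu.memory_reserved()


before = non_torch_memory_bytes()
torch.npu.empty_cache()  # on NPU this may also let the runtime release non-torch memory
after = non_torch_memory_bytes()
# On CUDA the equivalent call only shrinks the caching-allocator pool, so the
# non-torch reading barely moves; on NPU the drop can be hundreds of MB.
print(f"non-torch memory cleared by empty_cache: {before - after} bytes")
```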

Resolution:

We record the peak non-torch memory value
(`non_torch_memory_before_empty_cache`) after profiling but before
`torch.npu.empty_cache()`. Then, we add the diff
(`non_torch_memory_cleared_by_empty_cache =
non_torch_memory_before_empty_cache - self.non_torch_memory`) to the non-torch
memory when calculating the available KV cache memory. This yields less KV
cache memory, which is the safer direction for avoiding OOM issues.
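
In code, the correction amounts to roughly the following. This is a simplified
sketch of the computation added in the worker diff below; the real
implementation pulls these values from `MemorySnapshot` and the
`memory_profiling` result:

```python
def available_kv_cache_memory(
    total_memory: int,
    gpu_memory_utilization: float,
    non_kv_cache_memory: int,                  # weights + peak activations + non-torch (after empty_cache)
    non_torch_memory_before_empty_cache: int,  # measured right after profile_run()
    non_torch_memory_after_empty_cache: int,   # what the profiler reports once empty_cache has run
) -> int:
    requested_memory = int(total_memory * gpu_memory_utilization)
    # Memory the Ascend runtime released on empty_cache; charging it back makes the
    # budget reflect the real non-torch peak, so we hand out less KV cache memory.
    cleared_by_empty_cache = (
        non_torch_memory_before_empty_cache - non_torch_memory_after_empty_cache
    )
    return requested_memory - non_kv_cache_memory - cleared_by_empty_cache
```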

---
> [!NOTE]
> This PR needs to wait for the main2main alignment to the latest vLLM commit
> before merging.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Before this PR, the non-torch memory used to calculate the available KV cache
memory was **0.90 G**, whereas its peak during real inference was **1.08 G**
(diff: **182.00 M**).

After this PR, this diff is added to the non-torch memory after profiling,
making the profiling result match real inference more closely.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Commit: 957804df56 (parent 812c722cfb), committed by GitHub
Author: Shanshan Shen
Date: 2026-02-25 14:28:08 +08:00
2 changed files with 57 additions and 31 deletions


@@ -2330,6 +2330,7 @@ class NPUModelRunner(GPUModelRunner):
         if self.lora_config:
             self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
+        self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB", m.consumed_memory / float(2**30))
         # wrap the model with full graph wrapper if needed.

@@ -37,6 +37,7 @@ from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import SupportedTask
 from vllm.utils.mem_constants import GiB_bytes
+from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -249,8 +250,25 @@ class NPUWorker(WorkerBase):
     def _init_device(self):
         device = torch.device(f"npu:{self.local_rank}")
         torch.npu.set_device(device)
+        gc.collect()
         torch.npu.empty_cache()
+        # take current memory snapshot
+        self.init_snapshot = MemorySnapshot()
+        self.requested_memory = self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization
+        if self.init_snapshot.free_memory < self.requested_memory:
+            GiB = lambda b: round(b / GiB_bytes, 2)
+            raise ValueError(
+                f"Free memory on device "
+                f"({GiB(self.init_snapshot.free_memory)}/"
+                f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                f"is less than desired GPU memory utilization "
+                f"({self.cache_config.gpu_memory_utilization}, "
+                f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
+                f"utilization or reduce GPU memory used by other processes."
+            )
         if (
             self.parallel_config.data_parallel_size > 1
             and self.parallel_config.data_parallel_size_local > 0
@@ -265,7 +283,6 @@ class NPUWorker(WorkerBase):
                 f"({visible_device_count})."
             )
-        self.init_npu_memory = torch.npu.mem_get_info()[0]
         # Initialize the distributed environment.
         self._init_worker_distributed_environment()
         # Set random seed.
@@ -300,43 +317,51 @@ class NPUWorker(WorkerBase):
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        gc.collect()
-        torch.npu.empty_cache()
-        torch.npu.reset_peak_memory_stats()
+        """Profiles the peak memory usage of the model to determine how much
+        memory can be used for KV cache without OOMs.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculates the free memory that can be used for KV cache in
+        bytes.
+        """
+        GiB = lambda b: b / GiB_bytes
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        _, total_npu_memory = torch.npu.mem_get_info()
-        self.model_runner.profile_run()
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
-        free_npu_memory, _ = torch.npu.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        assert self.init_npu_memory > free_npu_memory, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_npu_memory}, current free memory"
-            f" {free_npu_memory}. This happens when the NPU memory was "
-            "not properly cleaned up before initializing the vLLM instance."
-        )
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
-        # TODO: don`t need impl this func after empty_cache in
-        # Worker.determine_num_available_blocks() unified`
-        torch.npu.empty_cache()
-        torch_allocated_bytes = torch_npu.npu.memory_stats()["allocated_bytes.all.current"]
-        total_allocated_bytes = torch_npu.npu.mem_get_info()[1] - torch_npu.npu.mem_get_info()[0]
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-        available_kv_cache_memory = int(total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory)
-        available_kv_cache_memory = int(max(available_kv_cache_memory, 0))
-        logger.info(f"Available memory: {available_kv_cache_memory}, total memory: {total_npu_memory}")
-        return available_kv_cache_memory
+        with memory_profiling(
+            self.init_snapshot,
+            weights_memory=int(self.model_runner.model_memory_usage),
+        ) as profile_result:
+            self.model_runner.profile_run()
+            free_memory, total_memory = torch.npu.mem_get_info()
+            torch_memory = torch.npu.memory_reserved()
+            non_torch_memory_before_empty_cache = total_memory - free_memory - torch_memory
+        self.non_torch_memory = profile_result.non_torch_increase
+        self.peak_activation_memory = profile_result.torch_peak_increase
+        non_torch_memory_cleared_by_empty_cache = non_torch_memory_before_empty_cache - self.non_torch_memory
+
+        free_gpu_memory = profile_result.after_profile.free_memory
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
+            "Error in memory profiling. "
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container."
+        )
+        self.available_kv_cache_memory_bytes = (
+            self.requested_memory - profile_result.non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
+        )
+        logger.debug(profile_result)
+        logger.info_once(
+            "Available KV cache memory: %.2f GiB",
+            GiB(self.available_kv_cache_memory_bytes),
+            scope="local",
+        )
+        return int(self.available_kv_cache_memory_bytes)

     def execute_model(
         self,