[Bugfix] Fix multi-instance serving OOM on single card (#7427)

### What this PR does / why we need it?
Fix https://github.com/vllm-project/vllm-ascend/issues/7308.

Subtract `init_non_torch_memory` (memory that may already be held by the first instance)
from the total `non_torch_memory` when calculating `available_kv_cache_memory`.

In other words, directly use `non_torch_memory_increase` (already contained in
`non_kv_cache_memory`) to calculate `available_kv_cache_memory`, so memory held by
another instance is no longer charged against this instance's KV cache budget.
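
A minimal sketch of the new accounting (illustrative only; the variable names mirror the debug log fields shown below, not the literal worker code):

```python
def available_kv_cache_memory(requested_memory: float,
                              non_kv_cache_memory: float) -> float:
    """Sketch of the revised accounting.

    non_kv_cache_memory already includes this instance's own
    non_torch_memory_increase, so it is the only deduction needed.
    The old code additionally subtracted
    non_torch_memory_cleared_by_empty_cache, which on a shared card
    includes memory still held by the first instance and can push
    the result below zero.
    """
    return requested_memory - non_kv_cache_memory
```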

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Launch two vllm-ascend instances sequentially on a single card (the second is started while the first is still serving).

```bash
# Launch first instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8100 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}'  \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager

# Launch second instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8101 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}'  \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
```

**Before this PR:**

```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340388298034668 GiB
init_non_torch_memory: 0.3616676330566406 GiB
non_torch_memory_before_empty_cache: 0.3896217346191406 GiB
non_torch_memory_increase: 0.0279541015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.3616676330566406 GiB
------------------------------------------------------------------

# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2336344718933105 GiB
init_non_torch_memory: 18.37220001220703 GiB
non_torch_memory_before_empty_cache: 18.399906158447266 GiB
non_torch_memory_increase: 0.02754974365234375 GiB
non_torch_memory_cleared_by_empty_cache: 18.372356414794922 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: -1.32 GiB
```
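
As a quick sanity check, plugging the second instance's numbers from the log above into the old formula reproduces the negative budget (a standalone snippet, values in GiB):

```python
requested_memory = 18.287109375
non_kv_cache_memory = 1.2336344718933105
cleared_by_empty_cache = 18.372356414794922  # mostly memory still held by the first instance

print(requested_memory - non_kv_cache_memory - cleared_by_empty_cache)
# ≈ -1.32 GiB -> no room for any KV cache blocks
```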

**After this PR:**

```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340540885925293 GiB
init_non_torch_memory: 0.36182403564453125 GiB
non_torch_memory_before_empty_cache: 0.38979339599609375 GiB
non_torch_memory_increase: 0.0279693603515625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------

# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.233344554901123 GiB
init_non_torch_memory: 18.74309539794922 GiB
non_torch_memory_before_empty_cache: 18.770355224609375 GiB
non_torch_memory_increase: 0.02725982666015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: 17.05 GiB
```
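
With the cleared-by-empty-cache term reported as 0.0 and no longer affecting the result, the same check gives the expected budget:

```python
requested_memory = 18.287109375
non_kv_cache_memory = 1.233344554901123

print(requested_memory - non_kv_cache_memory)
# ≈ 17.05 GiB available for the KV cache
```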

- vLLM version: v0.17.0
- vLLM main: 4497431df6

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Shanshan Shen authored on 2026-03-23 14:22:59 +08:00, committed by GitHub
parent 44ef9a36ac, commit 5c0d02f689
5 changed files with 348 additions and 14 deletions

View File

@@ -41,6 +41,8 @@ e2e-singlecard:
     estimated_time: 258
   - name: tests/e2e/singlecard/test_vlm.py
     estimated_time: 495
+  - name: tests/e2e/singlecard/test_multi_instance.py
+    estimated_time: 120
   - name: tests/e2e/singlecard/test_xlite.py
     estimated_time: 135
   - name: tests/e2e/singlecard/compile/test_norm_quant_fusion.py

View File

@@ -264,6 +264,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
         additional_config={"layer_sharding": ["q_b_proj", "o_proj"]},
         reasoning_parser="deepseek_v3",
         tokenizer_mode="deepseek_v32",
+        gpu_memory_utilization=0.8,
     ) as vllm_model:
         vllm_model.generate_greedy(short_example_prompts, max_tokens)
         vllm_model.generate_greedy(long_example_prompts, max_tokens)
@@ -292,6 +293,7 @@ def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
         additional_config={"layer_sharding": ["q_b_proj", "o_proj"], "enable_sparse_c8": True},
         reasoning_parser="deepseek_v3",
         tokenizer_mode="deepseek_v32",
+        gpu_memory_utilization=0.8,
     ) as vllm_model:
         vllm_model.generate_greedy(short_example_prompts, max_tokens)
         vllm_model.generate_greedy(long_example_prompts, max_tokens)

View File

@@ -0,0 +1,93 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Two VllmRunner instances are nested so that the first instance's worker
process is still holding NPU memory when the second instance's worker process
starts. Both instances must:
1. Initialize without raising any exception (no OOM during
determine_available_memory / KV-cache allocation).
2. Successfully complete a short generation request.
The model is Qwen/Qwen3-0.6B (~0.5 GiB weights) and gpu_memory_utilization
is set to 0.4 per instance so that two instances comfortably fit on a single
64 GiB Ascend 910B card while leaving enough headroom to avoid the
pre-fix negative-KV-cache condition.
"""
import os

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODEL = "Qwen/Qwen3-0.6B"
_PROMPTS = ["Hello, my name is"]
_MAX_TOKENS = 5

# Use a low utilization so two instances fit side-by-side on one card:
# 2 × 0.4 × card_total ≤ card_total (holds for any card ≥ 1 GiB)
_GPU_MEM_UTIL = 0.4
_MAX_MODEL_LEN = 512


def test_two_instances_on_single_card() -> None:
    """
    Regression test for PR #7427 (multi-instance OOM on single card).

    Start a first vllm-ascend instance; while it is still running and holding
    NPU memory, start a second instance with identical settings. Both must
    initialize correctly and produce non-empty outputs.

    Failure signature (pre-fix):
        RuntimeError / ValueError during the second instance's init, or
        "Available KV cache memory: -X.XX GiB" in the logs followed by
        zero KV blocks being allocated.
    """
    # ── First instance ──────────────────────────────────────────────────
    with VllmRunner(
        MODEL,
        max_model_len=_MAX_MODEL_LEN,
        gpu_memory_utilization=_GPU_MEM_UTIL,
        enforce_eager=True,
    ) as runner1:
        # ── Second instance starts while first is still alive ────────────
        # This is the exact scenario from PR #7427: the second worker process
        # sees a reduced init_snapshot.free_memory because the first instance's
        # worker is still holding NPU memory.
        with VllmRunner(
            MODEL,
            max_model_len=_MAX_MODEL_LEN,
            gpu_memory_utilization=_GPU_MEM_UTIL,
            enforce_eager=True,
        ) as runner2:
            outputs2 = runner2.generate_greedy(_PROMPTS, max_tokens=_MAX_TOKENS)
            outputs1 = runner1.generate_greedy(_PROMPTS, max_tokens=_MAX_TOKENS)

    # ── Assertions ───────────────────────────────────────────────────────
    assert outputs1, "First instance produced no outputs"
    assert outputs2, "Second instance produced no outputs"

    _, text1 = outputs1[0]
    _, text2 = outputs2[0]
    assert text1, "First instance output text is empty — model may have failed to run"
    assert text2, (
        "Second instance output text is empty — "
        "KV cache may have been allocated with zero blocks (pre-fix OOM regression)"
    )

View File

@@ -0,0 +1,248 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch

from vllm.utils.mem_constants import GiB_bytes

from tests.ut.base import TestBase


class TestDetermineAvailableMemoryMultiInstance(TestBase):
    """Tests for determine_available_memory() focusing on the multi-instance
    OOM regression (PR #7427)."""

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    def _make_worker(
        self,
        requested_memory: int,
        init_free_memory: int,
        init_total_memory: int,
        model_memory_usage: int | None = None,
    ):
        """Return a minimally-configured NPUWorker mock with memory state set."""
        from vllm_ascend.worker.worker import NPUWorker

        if model_memory_usage is None:
            model_memory_usage = int(0.5 * GiB_bytes)  # Qwen3-0.6B ~0.5 GiB

        with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
            worker = NPUWorker()

        worker.model_runner = MagicMock()
        worker.model_runner.model_memory_usage = model_memory_usage

        mock_snapshot = MagicMock()
        mock_snapshot.free_memory = init_free_memory
        mock_snapshot.total_memory = init_total_memory
        worker.init_snapshot = mock_snapshot

        worker.requested_memory = requested_memory
        return worker

    @staticmethod
    def _make_profile_result(free_memory_after: int, non_kv_cache_memory: int):
        """Return a mock profile_result compatible with memory_profiling output."""
        profile_result = MagicMock()
        profile_result.after_profile.free_memory = free_memory_after
        profile_result.non_kv_cache_memory = non_kv_cache_memory
        return profile_result

    @staticmethod
    def _patch_memory_profiling(profile_result):
        """Return a mock for `memory_profiling` that yields *profile_result*."""
        mock_ctx = MagicMock()
        mock_ctx.__enter__ = MagicMock(return_value=profile_result)
        mock_ctx.__exit__ = MagicMock(return_value=False)
        mock_profiling = MagicMock(return_value=mock_ctx)
        return patch("vllm_ascend.worker.worker.memory_profiling", mock_profiling)

    # ------------------------------------------------------------------ #
    # Tests
    # ------------------------------------------------------------------ #
    @patch("vllm_ascend.worker.worker.logger")
    def test_single_instance_positive_kv_cache(self, mock_logger):
        """Baseline: single instance on an empty card yields positive KV cache."""
        total = int(64 * GiB_bytes)
        gpu_util = 0.9
        requested_memory = int(total * gpu_util)  # 57.6 GiB
        init_free = int(62 * GiB_bytes)  # almost all free
        non_kv_cache = int(0.5 * GiB_bytes)  # Qwen3-0.6B weights

        worker = self._make_worker(requested_memory, init_free, total)
        profile_result = self._make_profile_result(
            free_memory_after=init_free - non_kv_cache,
            non_kv_cache_memory=non_kv_cache,
        )

        with self._patch_memory_profiling(profile_result):
            result = worker.determine_available_memory()

        expected = requested_memory - non_kv_cache
        self.assertEqual(result, expected)
        self.assertGreater(result, 0)

    @patch("vllm_ascend.worker.worker.logger")
    def test_second_instance_on_same_card_positive_kv_cache(self, mock_logger):
        """
        Regression test for PR #7427.

        Scenario (64 GiB Ascend 910B card, two Qwen3-0.6B instances,
        gpu_memory_utilization=0.4):

        ┌───────────────────────────────────────────────────────────────┐
        │ Card total: 64 GiB │
        │ Instance 1: requested_memory = 64 * 0.4 = 25.6 GiB (in use) │
        │ Instance 2 start: init_snapshot.free_memory ≈ 38.4 GiB │
        │ Instance 2: requested_memory = 25.6 GiB │
        │ Profiling (fixed): non_kv_cache_memory = 0.5 GiB (weights) │
        │ available = 25.6 - 0.5 = 25.1 GiB → must be > 0 ✓ │
        └───────────────────────────────────────────────────────────────┘

        Before the fix, non_kv_cache_memory was inflated to include first
        instance memory (~25.6 GiB), yielding available ≈ -1.32 GiB (OOM).
        """
        total = int(64 * GiB_bytes)
        gpu_util = 0.4
        requested_memory = int(total * gpu_util)  # 25.6 GiB

        # First instance already occupies its full requested_memory slice
        first_instance_used = requested_memory  # 25.6 GiB
        init_free = total - first_instance_used  # ~38.4 GiB

        # After the fix: profiling correctly reports only the second
        # instance's own model weights, not the first instance's memory.
        non_kv_cache = int(0.5 * GiB_bytes)  # Qwen3-0.6B weights

        worker = self._make_worker(requested_memory, init_free, total)
        profile_result = self._make_profile_result(
            free_memory_after=init_free - non_kv_cache,
            non_kv_cache_memory=non_kv_cache,
        )

        with self._patch_memory_profiling(profile_result):
            result = worker.determine_available_memory()

        self.assertGreater(
            result, 0,
            "Second instance must have positive KV cache memory. "
            "A non-positive value means the multi-instance OOM bug "
            "(PR #7427) has regressed.",
        )
        expected = requested_memory - non_kv_cache
        self.assertEqual(result, expected)

        # Verify model_runner.profile_run() was called during profiling
        worker.model_runner.profile_run.assert_called_once()

    @patch("vllm_ascend.worker.worker.logger")
    def test_second_instance_buggy_non_kv_cache_gives_negative(self, mock_logger):
        """
        Documents the *pre-fix* buggy behaviour that PR #7427 addresses.

        When non_kv_cache_memory is erroneously inflated to include memory
        already held by the first instance (~25.6 GiB extra), the formula

            available = requested_memory - non_kv_cache_memory

        yields a negative value, confirming why the fix was necessary.

        This test is intentionally asserting the *negative* outcome to
        document the regressed state; it is NOT testing the fix itself.
        """
        total = int(64 * GiB_bytes)
        gpu_util = 0.4
        requested_memory = int(total * gpu_util)  # 25.6 GiB
        first_instance_used = requested_memory  # 25.6 GiB
        init_free = total - first_instance_used  # ~38.4 GiB

        # Buggy: non_kv_cache_memory = first-instance memory + second-instance weights
        buggy_non_kv_cache = int((25.6 + 0.5) * GiB_bytes)  # ~26.1 GiB

        worker = self._make_worker(requested_memory, init_free, total)
        profile_result = self._make_profile_result(
            # free_memory decreased only by the actual new allocation (weights)
            free_memory_after=init_free - int(0.5 * GiB_bytes),
            non_kv_cache_memory=buggy_non_kv_cache,
        )

        with self._patch_memory_profiling(profile_result):
            result = worker.determine_available_memory()

        # Pre-fix: 25.6 GiB - 26.1 GiB = -0.5 GiB (negative → OOM)
        self.assertLess(
            result, 0,
            "With the pre-fix (buggy) non_kv_cache_memory the result must be "
            "negative; this documents the OOM regression that PR #7427 fixed.",
        )

    @patch("vllm_ascend.worker.worker.logger")
    def test_assert_raises_when_free_memory_increases_after_profile(self, mock_logger):
        """
        determine_available_memory() must raise AssertionError when free memory
        after profiling is greater than before (external process released memory
        during profiling, invalidating the measurement).
        """
        total = int(64 * GiB_bytes)
        requested_memory = int(total * 0.9)
        init_free = int(60 * GiB_bytes)
        worker = self._make_worker(requested_memory, init_free, total)

        # Abnormal: free memory increased after profiling
        profile_result = self._make_profile_result(
            free_memory_after=init_free + int(1 * GiB_bytes),  # went UP
            non_kv_cache_memory=int(0.5 * GiB_bytes),
        )

        with self._patch_memory_profiling(profile_result):
            with self.assertRaises(AssertionError) as ctx:
                worker.determine_available_memory()

        self.assertIn("Error in memory profiling", str(ctx.exception))

    @patch("vllm_ascend.worker.worker.logger")
    def test_second_instance_tight_memory_still_positive(self, mock_logger):
        """
        Edge case: card is almost full when second instance starts.

        Even with very little free memory left, as long as requested_memory >
        non_kv_cache_memory (i.e. there is room for at least some KV blocks),
        the result must be positive.
        """
        total = int(32 * GiB_bytes)  # smaller card (e.g. 910B1)
        gpu_util = 0.3
        requested_memory = int(total * gpu_util)  # 9.6 GiB

        # First instance has consumed most of its requested slice
        first_instance_used = requested_memory  # 9.6 GiB
        init_free = total - first_instance_used  # 22.4 GiB
        non_kv_cache = int(0.5 * GiB_bytes)  # Qwen3-0.6B

        worker = self._make_worker(requested_memory, init_free, total)
        profile_result = self._make_profile_result(
            free_memory_after=init_free - non_kv_cache,
            non_kv_cache_memory=non_kv_cache,
        )

        with self._patch_memory_profiling(profile_result):
            result = worker.determine_available_memory()

        self.assertGreater(result, 0)
        self.assertEqual(result, requested_memory - non_kv_cache)

View File

@@ -341,13 +341,6 @@ class NPUWorker(WorkerBase):
             weights_memory=int(self.model_runner.model_memory_usage),
         ) as profile_result:
             self.model_runner.profile_run()
-
-        free_memory, total_memory = torch.npu.mem_get_info()
-        torch_memory = torch.npu.memory_reserved()
-        non_torch_memory_before_empty_cache = total_memory - free_memory - torch_memory
-        self.non_torch_memory = profile_result.non_torch_increase
-        self.peak_activation_memory = profile_result.torch_peak_increase
-        non_torch_memory_cleared_by_empty_cache = non_torch_memory_before_empty_cache - self.non_torch_memory
 
         free_gpu_memory = profile_result.after_profile.free_memory
         assert self.init_snapshot.free_memory > free_gpu_memory, (
@@ -359,16 +352,12 @@ class NPUWorker(WorkerBase):
             "To fix this, ensure consistent GPU memory allocation or "
             "isolate vLLM in its own container."
         )
-        self.available_kv_cache_memory_bytes = (
-            self.requested_memory - profile_result.non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
-        )
+        self.available_kv_cache_memory_bytes = self.requested_memory - profile_result.non_kv_cache_memory
         logger.debug(profile_result)
         logger.info_once(
-            "Available KV cache memory: %.2f GiB",
-            GiB(self.available_kv_cache_memory_bytes),
-            scope="local",
+            "Available KV cache memory: %.2f GiB", GiB(self.available_kv_cache_memory_bytes), scope="local"
         )
 
         return int(self.available_kv_cache_memory_bytes)
 
     def execute_model(