### What this PR does / why we need it?
Fix https://github.com/vllm-project/vllm-ascend/issues/7308.
This PR subtracts `init_non_torch_memory` (which may include memory still held by the first instance) from the total `non_torch_memory` when calculating `available_kv_cache_memory`. As a result, only `non_torch_memory_increase`, which is already contained in `non_kv_cache_memory`, is used to calculate `available_kv_cache_memory`.
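A minimal sketch of this bookkeeping follows. It is reconstructed from the log fields under "How was this patch tested?", not copied from the patch. The function names are hypothetical, all values are in GiB, and the clamp at zero is an assumption of the sketch. With the "Before this PR" numbers, the pre-fix helper reproduces the logged `non_torch_memory_cleared_by_empty_cache` of about 18.37 GiB. With the "After this PR" numbers, the post-fix helper yields 0.0.

```python
# Hypothetical helpers reconstructed from the logged fields (values in GiB);
# this is a sketch, not the actual vllm-ascend implementation.


def cleared_by_empty_cache_before_fix(non_torch_before_empty_cache: float,
                                      non_torch_increase: float) -> float:
    # Pre-fix: every byte of non-torch memory except this worker's own
    # increase is treated as "cleared by empty_cache", including memory
    # that another process (the first instance) still holds.
    return non_torch_before_empty_cache - non_torch_increase


def cleared_by_empty_cache_after_fix(non_torch_before_empty_cache: float,
                                     non_torch_increase: float,
                                     init_non_torch: float) -> float:
    # Post-fix: init_non_torch_memory is subtracted as well, so memory that
    # was already in use before this worker started is not counted.
    # Clamping at zero is an assumption of this sketch.
    return max(0.0, non_torch_before_empty_cache - non_torch_increase - init_non_torch)


# Second instance, "Before this PR" log: ~18.37 GiB is treated as cleared.
print(cleared_by_empty_cache_before_fix(18.399906158447266, 0.02754974365234375))
# Second instance, "After this PR" log: nothing is treated as cleared.
print(cleared_by_empty_cache_after_fix(18.770355224609375, 0.02725982666015625,
                                       18.74309539794922))
```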
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Launch two vllm-ascend instances sequentially on a single card.
```bash
# Launch first instance
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8100 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
# Launch the second instance from another shell while the first one is still running
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-0.6B \
--port 8101 \
--host 0.0.0.0 \
--additional-config='{"enable_cpu_binding":true}' \
--gpu-memory-utilization 0.3 \
--max-num-seqs 1 \
--max-model-len 2048 \
--max-num-batched-tokens 2048 \
--no-enable-prefix-caching \
--enforce-eager
```
**Before this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340388298034668 GiB
init_non_torch_memory: 0.3616676330566406 GiB
non_torch_memory_before_empty_cache: 0.3896217346191406 GiB
non_torch_memory_increase: 0.0279541015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.3616676330566406 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2336344718933105 GiB
init_non_torch_memory: 18.37220001220703 GiB
non_torch_memory_before_empty_cache: 18.399906158447266 GiB
non_torch_memory_increase: 0.02754974365234375 GiB
non_torch_memory_cleared_by_empty_cache: 18.372356414794922 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: -1.32 GiB
```
**After this PR:**
```bash
# First instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.2340540885925293 GiB
init_non_torch_memory: 0.36182403564453125 GiB
non_torch_memory_before_empty_cache: 0.38979339599609375 GiB
non_torch_memory_increase: 0.0279693603515625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# Second instance:
------------------------------------------------------------------
requested_memory: 18.287109375 GiB
non_kv_cache_memory: 1.233344554901123 GiB
init_non_torch_memory: 18.74309539794922 GiB
non_torch_memory_before_empty_cache: 18.770355224609375 GiB
non_torch_memory_increase: 0.02725982666015625 GiB
non_torch_memory_cleared_by_empty_cache: 0.0 GiB
------------------------------------------------------------------
# available_kv_cache_memory = requested_memory - non_kv_cache_memory - non_torch_memory_cleared_by_empty_cache
Available KV cache memory: 17.05 GiB
```
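Substituting the second instance's logged values into the formula from the log comments reproduces both reported results. Below is a minimal check. The helper name is hypothetical and all values are in GiB. Before the fix, the memory already in use when the worker started (presumably mostly the first instance's) is subtracted a second time, which pushes the result below zero.

```python
def available_kv_cache_memory(requested_memory: float,
                              non_kv_cache_memory: float,
                              cleared_by_empty_cache: float) -> float:
    # Formula from the "# available_kv_cache_memory = ..." log comments (GiB).
    return requested_memory - non_kv_cache_memory - cleared_by_empty_cache


# Second instance, before this PR: ~18.37 GiB is counted as cleared.
before = available_kv_cache_memory(18.287109375, 1.2336344718933105, 18.372356414794922)
# Second instance, after this PR: nothing is counted as cleared.
after = available_kv_cache_memory(18.287109375, 1.233344554901123, 0.0)

print(f"before: {before:.2f} GiB")  # before: -1.32 GiB
print(f"after: {after:.2f} GiB")  # after: 17.05 GiB
```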
- vLLM version: v0.17.0
- vLLM main: 4497431df6
---------
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Two VllmRunner instances are nested so that the first instance's worker
process is still holding NPU memory when the second instance's worker process
starts. Both instances must:

1. Initialize without raising any exception (no OOM during
   determine_available_memory / KV-cache allocation).
2. Successfully complete a short generation request.

The model is Qwen/Qwen3-0.6B (~0.5 GiB weights) and gpu_memory_utilization
is set to 0.4 per instance so that two instances comfortably fit on a single
64 GiB Ascend 910B card while leaving enough headroom to avoid the
pre-fix negative-KV-cache condition.
"""

import os

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODEL = "Qwen/Qwen3-0.6B"
_PROMPTS = ["Hello, my name is"]
_MAX_TOKENS = 5
# Use a low utilization so two instances fit side-by-side on one card:
# 2 × 0.4 × card_total = 0.8 × card_total ≤ card_total, whatever the card size
_GPU_MEM_UTIL = 0.4
_MAX_MODEL_LEN = 512


def test_two_instances_on_single_card() -> None:
    """
    Regression test for PR #7427 (multi-instance OOM on single card).

    Start a first vllm-ascend instance; while it is still running and holding
    NPU memory, start a second instance with identical settings. Both must
    initialize correctly and produce non-empty outputs.

    Failure signature (pre-fix):
        RuntimeError / ValueError during the second instance's init, or
        "Available KV cache memory: -X.XX GiB" in the logs followed by
        zero KV blocks being allocated.
    """
    # ── First instance ──────────────────────────────────────────────────
    with VllmRunner(
        MODEL,
        max_model_len=_MAX_MODEL_LEN,
        gpu_memory_utilization=_GPU_MEM_UTIL,
        enforce_eager=True,
    ) as runner1:
        # ── Second instance starts while first is still alive ────────────
        # This is the exact scenario from PR #7427: the second worker process
        # sees a reduced init_snapshot.free_memory because the first instance's
        # worker is still holding NPU memory.
        with VllmRunner(
            MODEL,
            max_model_len=_MAX_MODEL_LEN,
            gpu_memory_utilization=_GPU_MEM_UTIL,
            enforce_eager=True,
        ) as runner2:
            outputs2 = runner2.generate_greedy(_PROMPTS, max_tokens=_MAX_TOKENS)

        outputs1 = runner1.generate_greedy(_PROMPTS, max_tokens=_MAX_TOKENS)

    # ── Assertions ───────────────────────────────────────────────────────
    assert outputs1, "First instance produced no outputs"
    assert outputs2, "Second instance produced no outputs"

    _, text1 = outputs1[0]
    _, text2 = outputs2[0]

    assert text1, "First instance output text is empty — model may have failed to run"
    assert text2, (
        "Second instance output text is empty — "
        "KV cache may have been allocated with zero blocks (pre-fix OOM regression)"
    )