[CI] Add wait logic for each individual case (#6036)

### What this PR does / why we need it?
Wait until the NPU memory is clean
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
Li Wang
2026-01-20 21:05:44 +08:00
committed by GitHub
parent 750c06c78a
commit 8cf1e8d8a7
3 changed files with 84 additions and 3 deletions

View File

@@ -18,9 +18,11 @@
#
import contextlib
import functools
import gc
import json
import logging
import multiprocessing
import os
import shlex
import subprocess
@@ -78,6 +80,77 @@ logger = logging.getLogger(__name__)
_TEST_DIR = os.path.dirname(__file__)
def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
import torch_npu # type: ignore
# We can try to clean up memory in this subprocess, though it mostly affects this process.
# But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps.
gc.collect()
torch.npu.empty_cache()
_, total_npu_memory = torch.npu.mem_get_info()
start_time = time.time()
while True:
free_bytes, _ = torch.npu.mem_get_info()
if free_bytes / total_npu_memory >= target_free_percentage:
print(f'check_npu_memory_worker: npu free memory decreased target value.')
return # Success
elapsed = time.time() - start_time
if elapsed > max_wait_seconds:
# Print to stderr so it's visible in test logs even if captured
print(
f"Timeout: NPU memory free size did not reach "
f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
file=sys.stderr
)
sys.exit(1) # Failure
print(
f"Waiting for NPU memory to be free: "
f"{free_bytes / 1024**3:.2f} GB available, "
f"Elapsed time: {elapsed:.2f} s."
)
# Try to clean up
gc.collect()
torch.npu.empty_cache()
time.sleep(1)
def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
"""Decorator to wait until the NPU memory free size is above target_free_percentage.
Args:
target_free_percentage (float): Target free memory percentage of total.
max_wait_seconds (float): Maximum wait time in seconds.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Clean up non-NPU resources in the main process
cleanup_dist_env_and_memory()
# Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process
ctx = multiprocessing.get_context("spawn")
p = ctx.Process(
target=_check_npu_memory_worker,
args=(target_free_percentage, max_wait_seconds)
)
p.start()
p.join()
if p.exitcode != 0:
raise TimeoutError(
f"Timeout: NPU memory free size did not reach "
f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
)
return func(*args, **kwargs)
return wrapper
return decorator
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
@@ -87,8 +160,13 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
# Only clean NPU cache if NPU is already initialized/available in this process.
# This prevents accidental initialization of NPU context in the main process,
# which would break subsequent forks.
if hasattr(torch, "npu") and torch.npu.is_initialized():
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class RemoteOpenAIServer:

View File

@@ -26,6 +26,7 @@ import torch
from vllm.utils.network_utils import get_open_port
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from tests.e2e.conftest import wait_until_npu_memory_free
MODELS = [
# Offline data parallel mode will be not supported/useful for dense models
@@ -137,6 +138,7 @@ def _run_worker_process(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [4, 36])
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
@wait_until_npu_memory_free(target_free_percentage=0.6)
def test_models_aclgraph_capture_replay_metrics_dp2(
model: str,
max_tokens: int,

View File

@@ -20,7 +20,7 @@ import os
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
os.environ["HCCL_BUFFSIZE"] = "768"
@@ -126,6 +126,7 @@ def test_models_pcp_dcp_piece_wise():
runner.model.generate(prompts, sampling_params)
@wait_until_npu_memory_free()
def test_pcp_basic():
prompts = [
"The capital of France is", "Hello, my name is Tom, I am",