[CI] Add wait logic for each individual case (#6036)

### What this PR does / why we need it?
Wait until the NPU memory is freed before running each test case, so leftover allocations from a previous case cannot cause spurious OOM failures.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
Li Wang
2026-01-20 21:05:44 +08:00
committed by GitHub
parent 750c06c78a
commit 8cf1e8d8a7
3 changed files with 84 additions and 3 deletions

View File

@@ -18,9 +18,11 @@
#
import contextlib
import functools
import gc
import json
import logging
import multiprocessing
import os
import shlex
import subprocess
@@ -78,6 +80,77 @@ logger = logging.getLogger(__name__)
_TEST_DIR = os.path.dirname(__file__)
def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    """Poll NPU memory in a fresh subprocess until the free fraction reaches the target.

    Intended to run as a spawned ``multiprocessing`` worker. Exits the process
    with status 0 on success and 1 when the target is not reached within
    ``max_wait_seconds``.

    Args:
        target_free_percentage: Required free fraction of total NPU memory (0..1).
        max_wait_seconds: Maximum time to keep polling, in seconds.
    """
    import torch_npu  # type: ignore # noqa: F401  # side effect: registers the 'npu' device
    # Best-effort cleanup inside this subprocess. For a fresh spawn this is
    # mostly a no-op, but it helps if anything lingered in this process.
    gc.collect()
    torch.npu.empty_cache()
    _, total_npu_memory = torch.npu.mem_get_info()
    start_time = time.time()
    while True:
        free_bytes, _ = torch.npu.mem_get_info()
        if free_bytes / total_npu_memory >= target_free_percentage:
            # flush so the message is not lost in the subprocess's stdout buffer
            print("check_npu_memory_worker: NPU free memory reached the target value.",
                  flush=True)
            return  # Success
        elapsed = time.time() - start_time
        if elapsed > max_wait_seconds:
            # Print to stderr so it's visible in test logs even if stdout is captured.
            print(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
                file=sys.stderr,
                flush=True,
            )
            sys.exit(1)  # Failure
        print(
            f"Waiting for NPU memory to be free: "
            f"{free_bytes / 1024**3:.2f} GB available, "
            f"Elapsed time: {elapsed:.2f} s.",
            flush=True,
        )
        # Try to reclaim cached allocations before polling again.
        gc.collect()
        torch.npu.empty_cache()
        time.sleep(1)
def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
    """Decorator that blocks until the NPU free memory fraction reaches the target.

    The memory check runs in a *spawned* subprocess so that the main (test)
    process never initializes an NPU context, which would break subsequent
    forks.

    Args:
        target_free_percentage (float): Target free memory fraction of total (0..1).
        max_wait_seconds (float): Maximum wait time in seconds.

    Raises:
        TimeoutError: If the target free percentage is not reached in time, or
            if the checker subprocess itself hangs or exits abnormally.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Clean up non-NPU resources in the main process first.
            cleanup_dist_env_and_memory()
            # Use "spawn" (not fork) so the child starts without inherited
            # accelerator state and the parent never touches the NPU.
            ctx = multiprocessing.get_context("spawn")
            p = ctx.Process(
                target=_check_npu_memory_worker,
                args=(target_free_percentage, max_wait_seconds),
            )
            p.start()
            # The worker enforces max_wait_seconds itself; the extra slack here
            # guards against the worker hanging (e.g. a stuck driver call
            # during import) — an unbounded join() would deadlock the CI run.
            p.join(max_wait_seconds + 30)
            if p.is_alive():
                p.terminate()
                p.join()
                raise TimeoutError(
                    f"NPU memory checker subprocess hung for more than "
                    f"{max_wait_seconds + 30} seconds and was terminated."
                )
            if p.exitcode != 0:
                raise TimeoutError(
                    f"Timeout: NPU memory free size did not reach "
                    f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
@@ -87,8 +160,13 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
# Only clean NPU cache if NPU is already initialized/available in this process.
# This prevents accidental initialization of NPU context in the main process,
# which would break subsequent forks.
if hasattr(torch, "npu") and torch.npu.is_initialized():
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class RemoteOpenAIServer: