[CI] Add wait logic for each individual case (#6036)
### What this PR does / why we need it?
Add a `wait_until_npu_memory_free` decorator so each CI test case waits until NPU memory has been freed before it starts running.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
@@ -18,9 +18,11 @@
 #
 
 import contextlib
+import functools
 import gc
 import json
 import logging
+import multiprocessing
 import os
 import shlex
 import subprocess
@@ -78,6 +80,77 @@ logger = logging.getLogger(__name__)
 _TEST_DIR = os.path.dirname(__file__)
 
 
+def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
+    import torch_npu  # type: ignore  # noqa: F401 (imported for its side effect of enabling torch.npu)
+
+    # We can try to clean up memory in this subprocess, though it mostly affects this process.
+    # But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps.
+    gc.collect()
+    torch.npu.empty_cache()
+
+    _, total_npu_memory = torch.npu.mem_get_info()
+    start_time = time.time()
+
+    while True:
+        free_bytes, _ = torch.npu.mem_get_info()
+        if free_bytes / total_npu_memory >= target_free_percentage:
+            print("check_npu_memory_worker: NPU free memory reached the target value.")
+            return  # Success
+
+        elapsed = time.time() - start_time
+        if elapsed > max_wait_seconds:
+            # Print to stderr so it's visible in test logs even if captured
+            print(
+                f"Timeout: NPU memory free size did not reach "
+                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
+                file=sys.stderr
+            )
+            sys.exit(1)  # Failure
+
+        print(
+            f"Waiting for NPU memory to be free: "
+            f"{free_bytes / 1024**3:.2f} GB available, "
+            f"Elapsed time: {elapsed:.2f} s."
+        )
+        # Try to clean up
+        gc.collect()
+        torch.npu.empty_cache()
+        time.sleep(1)
+
+
+def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
+    """Decorator to wait until the NPU free memory fraction is at least target_free_percentage.
+
+    Args:
+        target_free_percentage (float): Target free memory as a fraction of total NPU memory.
+        max_wait_seconds (float): Maximum wait time in seconds.
+    """
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # Clean up non-NPU resources in the main process
+            cleanup_dist_env_and_memory()
+
+            # Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process
+            ctx = multiprocessing.get_context("spawn")
+            p = ctx.Process(
+                target=_check_npu_memory_worker,
+                args=(target_free_percentage, max_wait_seconds)
+            )
+            p.start()
+            p.join()
+
+            if p.exitcode != 0:
+                raise TimeoutError(
+                    f"Timeout: NPU memory free size did not reach "
+                    f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
+                )
+
+            return func(*args, **kwargs)
+        return wrapper
+    return decorator
+
+
 def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
     destroy_model_parallel()
     destroy_distributed_environment()
@@ -87,8 +160,13 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
         import ray  # Lazy import Ray
         ray.shutdown()
     gc.collect()
-    torch.npu.empty_cache()
-    torch.npu.reset_peak_memory_stats()
+
+    # Only clean NPU cache if NPU is already initialized/available in this process.
+    # This prevents accidental initialization of NPU context in the main process,
+    # which would break subsequent forks.
+    if hasattr(torch, "npu") and torch.npu.is_initialized():
+        torch.npu.empty_cache()
+        torch.npu.reset_peak_memory_stats()
 
 
 class RemoteOpenAIServer:
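
Below is a minimal sketch of how a CI test could opt into the new wait logic. It is illustrative only: the test name, parameter values, and import path are assumptions; only the `wait_until_npu_memory_free` decorator comes from this patch.

```python
# Hypothetical usage sketch (not part of the diff above): the test body,
# import path, and parameter values are illustrative assumptions.
from tests.e2e.conftest import wait_until_npu_memory_free  # assumed location of the helper


@wait_until_npu_memory_free(target_free_percentage=0.9, max_wait_seconds=120)
def test_example_case():
    # Before this body runs, the decorator calls cleanup_dist_env_and_memory()
    # and then spawns a subprocess that polls torch.npu.mem_get_info() until
    # at least 90% of NPU memory is free, raising TimeoutError after 120 s.
    ...
```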