From 8cf1e8d8a7ca278cbb75896f39e9645c4b5b2f74 Mon Sep 17 00:00:00 2001 From: Li Wang Date: Tue, 20 Jan 2026 21:05:44 +0800 Subject: [PATCH] [CI] Add wait logic for each individual case (#6036) ### What this PR does / why we need it? Add a `wait_until_npu_memory_free` decorator to `tests/e2e/conftest.py`. Before a decorated test case runs, a spawned subprocess polls NPU memory until the free fraction of total memory reaches the target (default 0.5) or the timeout (default 50 s) expires, so the case starts only once enough NPU memory is free; on timeout the test fails with a `TimeoutError`. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: wangli Signed-off-by: leo-pony Co-authored-by: leo-pony --- tests/e2e/conftest.py | 82 ++++++++++++++++++- .../2-cards/test_aclgraph_capture_replay.py | 2 + .../4-cards/long_sequence/test_basic.py | 3 +- 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 6f3c4fbe..1d7fd9ec 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -18,9 +18,11 @@ # import contextlib +import functools import gc import json import logging +import multiprocessing import os import shlex import subprocess @@ -78,6 +80,77 @@ logger = logging.getLogger(__name__) _TEST_DIR = os.path.dirname(__file__) +def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float): + import torch_npu # type: ignore + + # We can try to clean up memory in this subprocess, though it mostly affects this process. + # But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps. 
+ gc.collect() + torch.npu.empty_cache() + + _, total_npu_memory = torch.npu.mem_get_info() + start_time = time.time() + + while True: + free_bytes, _ = torch.npu.mem_get_info() + if free_bytes / total_npu_memory >= target_free_percentage: + print(f'check_npu_memory_worker: npu free memory decreased target value.') + return # Success + + elapsed = time.time() - start_time + if elapsed > max_wait_seconds: + # Print to stderr so it's visible in test logs even if captured + print( + f"Timeout: NPU memory free size did not reach " + f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.", + file=sys.stderr + ) + sys.exit(1) # Failure + + print( + f"Waiting for NPU memory to be free: " + f"{free_bytes / 1024**3:.2f} GB available, " + f"Elapsed time: {elapsed:.2f} s." + ) + # Try to clean up + gc.collect() + torch.npu.empty_cache() + time.sleep(1) + + +def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50): + """Decorator to wait until the NPU memory free size is above target_free_percentage. + + Args: + target_free_percentage (float): Target free memory percentage of total. + max_wait_seconds (float): Maximum wait time in seconds. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Clean up non-NPU resources in the main process + cleanup_dist_env_and_memory() + + # Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process + ctx = multiprocessing.get_context("spawn") + p = ctx.Process( + target=_check_npu_memory_worker, + args=(target_free_percentage, max_wait_seconds) + ) + p.start() + p.join() + + if p.exitcode != 0: + raise TimeoutError( + f"Timeout: NPU memory free size did not reach " + f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds." 
+ ) + + return func(*args, **kwargs) + return wrapper + return decorator + + def cleanup_dist_env_and_memory(shutdown_ray: bool = False): destroy_model_parallel() destroy_distributed_environment() @@ -87,8 +160,13 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): import ray # Lazy import Ray ray.shutdown() gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() + + # Only clean NPU cache if NPU is already initialized/available in this process. + # This prevents accidental initialization of NPU context in the main process, + # which would break subsequent forks. + if hasattr(torch, "npu") and torch.npu.is_initialized(): + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() class RemoteOpenAIServer: diff --git a/tests/e2e/multicard/2-cards/test_aclgraph_capture_replay.py b/tests/e2e/multicard/2-cards/test_aclgraph_capture_replay.py index c4195bae..0c384606 100644 --- a/tests/e2e/multicard/2-cards/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/2-cards/test_aclgraph_capture_replay.py @@ -26,6 +26,7 @@ import torch from vllm.utils.network_utils import get_open_port from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type +from tests.e2e.conftest import wait_until_npu_memory_free MODELS = [ # Offline data parallel mode will be not supported/useful for dense models @@ -137,6 +138,7 @@ def _run_worker_process( @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [4, 36]) @patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1"}) +@wait_until_npu_memory_free(target_free_percentage=0.6) def test_models_aclgraph_capture_replay_metrics_dp2( model: str, max_tokens: int, diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py index 92808252..1af86eb6 100644 --- a/tests/e2e/multicard/4-cards/long_sequence/test_basic.py +++ b/tests/e2e/multicard/4-cards/long_sequence/test_basic.py @@ -20,7 +20,7 @@ import os 
from vllm import SamplingParams -from tests.e2e.conftest import VllmRunner +from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free os.environ["HCCL_BUFFSIZE"] = "768" @@ -126,6 +126,7 @@ def test_models_pcp_dcp_piece_wise(): runner.model.generate(prompts, sampling_params) +@wait_until_npu_memory_free() def test_pcp_basic(): prompts = [ "The capital of France is", "Hello, my name is Tom, I am",