From 11122080523853960b6f2605dd9c308f53bb9e8a Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Wed, 7 Jan 2026 09:25:55 +0800 Subject: [PATCH] [Refactor] Cleanup platform (#5566) ### What this PR does / why we need it? 1. add `COMPILATION_PASS_KEY` constant 2. clean up useless platform interface `empty_cache`, `synchronize`, `mem_get_info`, `clear_npu_memory` 3. rename `CUSTOM_OP_REGISTERED` to `_CUSTOM_OP_REGISTERED` 4. remove useless env `VLLM_ENABLE_CUDAGRAPH_GC` NPUPlatform is the interface called by vLLM. Do not call it inside vllm-ascend. ### Does this PR introduce _any_ user-facing change? This PR is just a cleanup. All CI should pass. ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/7157596103666ee7ccb7008acee8bff8a8ff1731 Signed-off-by: wangxiyuan --- tests/ut/attention/test_mla_cp.py | 1 - tests/ut/device_allocator/test_camem.py | 15 ++- tests/ut/test_platform.py | 109 ------------------ tests/ut/worker/test_worker_v1.py | 82 +++++-------- vllm_ascend/compilation/compiler_interface.py | 3 +- vllm_ascend/device_allocator/camem.py | 11 +- vllm_ascend/platform.py | 41 ++----- vllm_ascend/utils.py | 1 + vllm_ascend/worker/worker.py | 33 +++--- 9 files changed, 79 insertions(+), 217 deletions(-) diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py index 10358547..74d9ecbc 100755 --- a/tests/ut/attention/test_mla_cp.py +++ b/tests/ut/attention/test_mla_cp.py @@ -879,7 +879,6 @@ class TestAscendMLAImpl(TestBase): B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim # total: [4, 4, 8] test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)] for test_case in test_cases: - print(test_case) self.impl.dcp_size = test_case[0] self.impl.pcp_size = test_case[1] mock_dcp.world_size = test_case[0] diff --git a/tests/ut/device_allocator/test_camem.py b/tests/ut/device_allocator/test_camem.py index ec500e73..41953934 100644 --- a/tests/ut/device_allocator/test_camem.py +++ 
b/tests/ut/device_allocator/test_camem.py @@ -128,10 +128,17 @@ class TestCaMem(PytestBase): 2000: data2, } - # mock is_pin_memory_available, return False as some machine only has cpu - with patch( - "vllm_ascend.device_allocator.camem.NPUPlatform.is_pin_memory_available", - return_value=False): + # Mock torch.empty to force pin_memory=False + original_torch_empty = torch.empty + + def mock_torch_empty(*args, **kwargs): + # If pin_memory was explicitly set to True, change it to False + if 'pin_memory' in kwargs and kwargs['pin_memory'] is True: + kwargs['pin_memory'] = False + return original_torch_empty(*args, **kwargs) + + with patch("vllm_ascend.device_allocator.camem.torch.empty", + side_effect=mock_torch_empty): allocator.sleep(offload_tags="tag1") # only offload tag1, other tag2 call unmap_and_release diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index d7a5ecba..6f2eeec1 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -120,115 +120,6 @@ class TestNPUPlatform(TestBase): self.assertIsNone(self.platform.inference_mode()) mock_inference_mode.assert_called_once() - @patch("torch.npu.set_device") - def test_set_device_normal(self, mock_set_device): - device = torch.device("npu:0") - self.platform.set_device(device) - mock_set_device.assert_called_once_with(device) - - @patch("torch.npu.set_device", - side_effect=RuntimeError("Device not available")) - def test_set_device_failure(self, mock_set_device): - device = torch.device("npu:0") - with self.assertRaises(RuntimeError): - self.platform.set_device(device) - mock_set_device.assert_called_once_with(device) - - @patch("torch.npu.empty_cache") - def test_empty_cache_normal(self, mock_empty_cache): - self.platform.empty_cache() - mock_empty_cache.assert_called_once() - - @patch("torch.npu.empty_cache", - side_effect=RuntimeError("Cache clearing failed")) - def test_empty_cache_failure(self, mock_empty_cache): - with self.assertRaises(RuntimeError): - 
self.platform.empty_cache() - mock_empty_cache.assert_called_once() - - @patch("torch.npu.synchronize") - def test_synchronize_normal(self, mock_synchronize): - self.platform.synchronize() - mock_synchronize.assert_called_once() - - @patch("torch.npu.synchronize", - side_effect=RuntimeError("Synchronization failed")) - def test_synchronize_failure(self, mock_synchronize): - with self.assertRaises(RuntimeError): - self.platform.synchronize() - mock_synchronize.assert_called_once() - - @patch("torch.npu.mem_get_info") - def test_mem_get_info_normal(self, mock_mem_get_info): - free_memory_size = 1024 - total_memory_size = 2048 - memory_info = (free_memory_size, total_memory_size) - mock_mem_get_info.return_value = memory_info - result = self.platform.mem_get_info() - self.assertIsInstance(result, tuple) - self.assertEqual(len(result), 2) - self.assertEqual(result, memory_info) - mock_mem_get_info.assert_called_once() - - @patch("torch.npu.mem_get_info", - side_effect=RuntimeError("NPU not available")) - def test_mem_get_info_failure(self, mock_mem_get_info): - with self.assertRaises(RuntimeError): - self.platform.mem_get_info() - mock_mem_get_info.assert_called_once() - - @patch("gc.collect") - @patch("torch.npu.empty_cache") - @patch("torch.npu.reset_peak_memory_stats") - def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache, - mock_gc_collect): - self.platform.clear_npu_memory() - - mock_gc_collect.assert_called_once() - mock_empty_cache.assert_called_once() - mock_reset_stats.assert_called_once() - - @patch("gc.collect", side_effect=Exception("GC failed")) - @patch("torch.npu.empty_cache") - @patch("torch.npu.reset_peak_memory_stats") - def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats, - mock_empty_cache, - mock_gc_collect): - with self.assertRaises(Exception): - self.platform.clear_npu_memory() - - mock_gc_collect.assert_called_once() - mock_empty_cache.assert_not_called() - mock_reset_stats.assert_not_called() - - 
@patch("gc.collect") - @patch("torch.npu.empty_cache", - side_effect=RuntimeError("Cache clear failed")) - @patch("torch.npu.reset_peak_memory_stats") - def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats, - mock_empty_cache, - mock_gc_collect): - with self.assertRaises(RuntimeError): - self.platform.clear_npu_memory() - - mock_gc_collect.assert_called_once() - mock_empty_cache.assert_called_once() - mock_reset_stats.assert_not_called() - - @patch("gc.collect") - @patch("torch.npu.empty_cache") - @patch("torch.npu.reset_peak_memory_stats", - side_effect=RuntimeError("Reset failed")) - def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats, - mock_empty_cache, - mock_gc_collect): - with self.assertRaises(RuntimeError): - self.platform.clear_npu_memory() - - mock_gc_collect.assert_called_once() - mock_empty_cache.assert_called_once() - mock_reset_stats.assert_called_once() - @patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.utils.update_aclgraph_sizes") @patch('vllm_ascend.utils.get_ascend_device_type', diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index e3f6dd14..49d6c86e 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -238,15 +238,18 @@ class TestNPUWorker(TestBase): @patch( "vllm_ascend.worker.worker.NPUWorker._init_worker_distributed_environment" ) - @patch("vllm_ascend.worker.worker.NPUPlatform") @patch("vllm_ascend.worker.worker.init_device_properties_triton") - def test_init_device(self, mock_init_triton, mock_platform, + @patch("torch.npu.set_device") + @patch("torch.npu.empty_cache") + @patch("torch.npu.mem_get_info") + def test_init_device(self, mock_mem_get_info, mock_set_device, + mock_empty_cache, mock_init_triton, mock_init_dist_env): """Test _init_device method""" from vllm_ascend.worker.worker import NPUWorker # Setup mock - mock_platform.mem_get_info.return_value = (1000, 2000) + mock_mem_get_info.return_value = 
(1000, 2000) # Create worker mock with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None): @@ -256,21 +259,13 @@ class TestNPUWorker(TestBase): worker.parallel_config = MagicMock() worker.parallel_config.local_world_size = 0 worker.parallel_config.data_parallel_size = 1 - worker.model_config.seed = 42 # Test _init_device result = worker._init_device() - # Verify NPUPlatform.set_device is called - mock_platform.set_device.assert_called_once() # Verify the parameter passed to set_device is a torch.device object - call_args = mock_platform.set_device.call_args[0][0] - self.assertEqual(str(call_args), "npu:1") - - mock_platform.empty_cache.assert_called_once() - mock_platform.seed_everything.assert_called_once_with(42) - mock_platform.mem_get_info.assert_called_once( + mock_mem_get_info.assert_called_once( ) # Called once in _init_device method mock_init_dist_env.assert_called_once( ) # Verify distributed initialization is called @@ -548,9 +543,8 @@ class TestNPUWorker(TestBase): # Verify returns None (empty string is considered false) self.assertIsNone(result) - @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.empty_cache") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") @patch("vllm_ascend.worker.worker.logger") @@ -559,15 +553,14 @@ class TestNPUWorker(TestBase): mock_logger, mock_torch_mem_get_info, mock_torch_memory_stats, - mock_platform_mem_get_info, - mock_platform_empty_cache, - mock_platform_clear_npu_memory, + mock_torch_empty_cache, + mock_torch_reset_peak_memory_stats, ): """Test determine_available_memory normal case (no non-torch memory allocation)""" from vllm_ascend.worker.worker import NPUWorker # Setup mock - test case without non-torch memory allocation - mock_platform_mem_get_info.side_effect = [ + 
mock_torch_mem_get_info.side_effect = [ (8000, 10000), # 1st call: before profile execution (7000, 10000), # 2nd call: after profile execution ] @@ -606,10 +599,8 @@ class TestNPUWorker(TestBase): result = worker.determine_available_memory() # Verify call count and order - mock_platform_clear_npu_memory.assert_called_once() - self.assertEqual(mock_platform_mem_get_info.call_count, 2) + self.assertEqual(mock_torch_mem_get_info.call_count, 4) worker.model_runner.profile_run.assert_called_once() - mock_platform_empty_cache.assert_called_once() # Verify calculation result with race condition simulation # Calculation logic: @@ -629,24 +620,22 @@ class TestNPUWorker(TestBase): # Verify log output mock_logger.info.assert_called_once() - @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.empty_cache") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") def test_determine_available_memory_with_non_torch_allocations( self, mock_torch_mem_get_info, mock_torch_memory_stats, - mock_platform_mem_get_info, - mock_platform_empty_cache, - mock_platform_clear_npu_memory, + mock_torch_empty_cache, + mock_torch_reset_peak_memory_stats, ): """Test determine_available_memory with significant non-torch memory allocation""" from vllm_ascend.worker.worker import NPUWorker # Setup mock - test case with large non-torch memory allocation - mock_platform_mem_get_info.side_effect = [ + mock_torch_mem_get_info.side_effect = [ (8000, 10000), # 1st call (7000, 10000), # 2nd call ] @@ -695,15 +684,17 @@ class TestNPUWorker(TestBase): expected_result = max(0, int(10000 * 0.9 - 5500)) self.assertEqual(result, expected_result) - @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") + 
@patch("torch.npu.mem_get_info") + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.empty_cache") def test_determine_available_memory_memory_profiling_error( - self, mock_platform_mem_get_info, mock_platform_clear_npu_memory): + self, mock_torch_empty_cache, mock_torch_reset_peak_memory_stats, + mock_torch_mem_get_info): """Test determine_available_memory throws exception on memory profiling error""" from vllm_ascend.worker.worker import NPUWorker # Setup mock: initial memory less than current free memory (error case) - mock_platform_mem_get_info.side_effect = [ + mock_torch_mem_get_info.side_effect = [ (8000, 10000), # 1st call (9000, 10000), # 2nd call: free memory increased instead ] @@ -722,24 +713,22 @@ class TestNPUWorker(TestBase): self.assertIn("Error in memory profiling", str(cm.exception)) - @patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory") - @patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache") - @patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info") + @patch("torch.npu.reset_peak_memory_stats") + @patch("torch.npu.empty_cache") @patch("torch_npu.npu.memory_stats") @patch("torch_npu.npu.mem_get_info") def test_determine_available_memory_negative_result( self, mock_torch_mem_get_info, mock_torch_memory_stats, - mock_platform_mem_get_info, - mock_platform_empty_cache, - mock_platform_clear_npu_memory, + mock_torch_empty_cache, + mock_torch_reset_peak_memory_stats, ): """Test determine_available_memory returns 0 when result is negative""" from vllm_ascend.worker.worker import NPUWorker # Setup mock: high peak memory causes negative available memory - mock_platform_mem_get_info.side_effect = [ + mock_torch_mem_get_info.side_effect = [ (8000, 10000), # 1st call (3000, 10000), # 2nd call ] @@ -989,12 +978,10 @@ class TestNPUWorker(TestBase): self.assertIn("Sleep mode can only be", str(cm.exception)) - @patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything") @patch("vllm_ascend.worker.worker.logger") 
@patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb") def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb, - mock_logger, - mock_seed_everything): + mock_logger): """Test compile_or_warm_up_model method - eager mode""" from vllm_ascend.worker.worker import NPUWorker @@ -1032,17 +1019,13 @@ class TestNPUWorker(TestBase): # Verify log output self.assertEqual(mock_logger.info.call_count, 4) - # Verify seed setting - mock_seed_everything.assert_called_once_with(12345) - # Verify atb warm up mock_warm_up_atb.assert_called_once() - @patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything") @patch("vllm_ascend.worker.worker.logger") @patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb") def test_compile_or_warm_up_model_with_graph_capture( - self, mock_warm_up_atb, mock_logger, mock_seed_everything): + self, mock_warm_up_atb, mock_logger): """Test compile_or_warm_up_model method - with graph capture enabled""" from vllm_ascend.worker.worker import NPUWorker @@ -1072,9 +1055,6 @@ class TestNPUWorker(TestBase): # Should call capture_model in non-eager mode worker.model_runner.capture_model.assert_called_once() - # Verify seed setting - mock_seed_everything.assert_called_once_with(67890) - # Verify atb warm up mock_warm_up_atb.assert_called_once() diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 1c706806..24b85e87 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -28,6 +28,7 @@ from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import COMPILATION_PASS_KEY def compile_fx(graph: GraphModule, example_inputs: list, @@ -51,7 +52,7 @@ def fusion_pass_compile( ) -> tuple[Optional[Callable], Optional[Any]]: def compile_inner(graph, example_inputs): - current_pass_manager = 
compiler_config["graph_fusion_manager"] + current_pass_manager = compiler_config[COMPILATION_PASS_KEY] graph = current_pass_manager(graph, runtime_shape) return graph diff --git a/vllm_ascend/device_allocator/camem.py b/vllm_ascend/device_allocator/camem.py index 1bd97ab4..1054263e 100644 --- a/vllm_ascend/device_allocator/camem.py +++ b/vllm_ascend/device_allocator/camem.py @@ -25,8 +25,6 @@ import torch from acl.rt import memcpy # type: ignore # noqa: F401 from vllm.logger import logger -from vllm_ascend.platform import NPUPlatform - def find_loaded_library(lib_name) -> Optional[str]: """ @@ -196,11 +194,10 @@ class CaMemAllocator: handle = data.handle if data.tag in offload_tags: size_in_bytes = handle[1] - cpu_backup_tensor = torch.empty( - size_in_bytes, - dtype=torch.uint8, - device='cpu', - pin_memory=NPUPlatform.is_pin_memory_available()) + cpu_backup_tensor = torch.empty(size_in_bytes, + dtype=torch.uint8, + device='cpu', + pin_memory=True) cpu_ptr = cpu_backup_tensor.data_ptr() ACL_MEMCPY_DEVICE_TO_HOST = 2 dest_max = cpu_ptr + size_in_bytes * 2 diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 4a60d3d5..9bdf62aa 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -15,9 +15,8 @@ # This file is a part of the vllm-ascend project. # -import gc import os -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional from uuid import uuid4 import torch @@ -26,18 +25,16 @@ from vllm.platforms import Platform, PlatformEnum # todo: please remove it when solve cuda hard code in vllm os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1" -# todo: please remove it when support controls garbage collection during CUDA graph capture. 
-os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1" from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.utils import refresh_block_size # isort: off from vllm_ascend.utils import ( - ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, - enable_sp, get_ascend_device_type, is_vl_model, update_aclgraph_sizes, - update_cudagraph_capture_sizes, update_default_aclgraph_sizes, - check_kv_extra_config) + ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, + COMPILATION_PASS_KEY, AscendDeviceType, enable_sp, get_ascend_device_type, + is_vl_model, update_aclgraph_sizes, update_cudagraph_capture_sizes, + update_default_aclgraph_sizes, check_kv_extra_config) if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -47,7 +44,7 @@ else: VllmConfig = None FlexibleArgumentParser = None -CUSTOM_OP_REGISTERED = False +_CUSTOM_OP_REGISTERED = False class NPUPlatform(Platform): @@ -74,7 +71,7 @@ class NPUPlatform(Platform): It is a parameter of inductor_config used to register custom passes. Currently, we only use Inductor's 'pattern matcher' functionality, so we define our own pass_key. 
""" - return "graph_fusion_manager" + return COMPILATION_PASS_KEY @classmethod def get_pass_manager_cls(cls) -> str: @@ -131,24 +128,6 @@ class NPUPlatform(Platform): def set_device(cls, device: torch.device): torch.npu.set_device(device) - @classmethod - def empty_cache(cls): - torch.npu.empty_cache() - - @classmethod - def synchronize(cls): - torch.npu.synchronize() - - @classmethod - def mem_get_info(cls) -> Tuple[int, int]: - return torch.npu.mem_get_info() - - @classmethod - def clear_npu_memory(cls): - gc.collect() - torch.npu.empty_cache() - torch.npu.reset_peak_memory_stats() - @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # initialize ascend config from vllm additional_config @@ -351,8 +330,8 @@ class NPUPlatform(Platform): # from vllm_ascend.utils import enable_custom_op # enable_custom_op() # set custom ops path - global CUSTOM_OP_REGISTERED - if CUSTOM_OP_REGISTERED: + global _CUSTOM_OP_REGISTERED + if _CUSTOM_OP_REGISTERED: return CUR_DIR = os.path.dirname(os.path.realpath(__file__)) CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors", @@ -365,7 +344,7 @@ class NPUPlatform(Platform): "ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}" else: os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH - CUSTOM_OP_REGISTERED = True + _CUSTOM_OP_REGISTERED = True @classmethod def get_attn_backend_cls(cls, selected_backend, attn_selector_config): diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d9d92754..711fbbd1 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -41,6 +41,7 @@ if TYPE_CHECKING: else: VllmConfig = None +COMPILATION_PASS_KEY = "graph_fusion_manager" ASCEND_QUANTIZATION_METHOD = "ascend" COMPRESSED_TENSORS_METHOD = "compressed-tensors" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 2da2551e..2529a298 100644 --- a/vllm_ascend/worker/worker.py +++ 
b/vllm_ascend/worker/worker.py @@ -18,6 +18,7 @@ # import copy +import gc from types import NoneType from typing import Optional @@ -55,15 +56,19 @@ from vllm_ascend.cpu_binding import bind_cpus from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton -from vllm_ascend.platform import NPUPlatform from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type, enable_sp, get_ascend_device_type, - register_ascend_customop) + register_ascend_customop, vllm_version_is) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402 +if vllm_version_is("0.13.0"): + from vllm.model_executor.utils import set_random_seed +else: + from vllm.utils.torch_utils import set_random_seed + torch_non_c_binding_in_graph_functions_npu = dict.fromkeys( ["torch.npu.current_stream"], TorchInGraphFunctionVariable, @@ -147,7 +152,7 @@ class NPUWorker(WorkerBase): self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER def sleep(self, level: int = 1) -> None: - free_bytes_before_sleep = NPUPlatform.mem_get_info()[0] + free_bytes_before_sleep = torch.npu.mem_get_info()[0] # Save the buffers before level 2 sleep if level == 2: model = self.model_runner.model @@ -157,7 +162,7 @@ class NPUWorker(WorkerBase): } allocator = CaMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) - free_bytes_after_sleep, total = NPUPlatform.mem_get_info() + free_bytes_after_sleep, total = torch.npu.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
@@ -210,8 +215,8 @@ class NPUWorker(WorkerBase): def _init_device(self): device = torch.device(f"npu:{self.local_rank}") - NPUPlatform.set_device(device) - NPUPlatform.empty_cache() + torch.npu.set_device(device) + torch.npu.empty_cache() if (self.parallel_config.data_parallel_size > 1 and self.parallel_config.data_parallel_size_local > 0 @@ -226,11 +231,11 @@ class NPUWorker(WorkerBase): f"be less than or equal to the number of visible devices " f"({visible_device_count}).") - self.init_npu_memory = NPUPlatform.mem_get_info()[0] + self.init_npu_memory = torch.npu.mem_get_info()[0] # Initialize the distributed environment. self._init_worker_distributed_environment() # Set random seed. - NPUPlatform.seed_everything(self.model_config.seed) + set_random_seed(self.model_config.seed) # Initialize device properties used by triton kernels. init_device_properties_triton() return device @@ -258,16 +263,18 @@ class NPUWorker(WorkerBase): def determine_available_memory(self) -> int: # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. - NPUPlatform.clear_npu_memory() + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - _, total_npu_memory = NPUPlatform.mem_get_info() + _, total_npu_memory = torch.npu.mem_get_info() self.model_runner.profile_run() # Calculate the number of blocks that can be allocated with the # profiled peak memory. - free_npu_memory, _ = NPUPlatform.mem_get_info() + free_npu_memory, _ = torch.npu.mem_get_info() # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. 
assert self.init_npu_memory > free_npu_memory, ( @@ -280,7 +287,7 @@ class NPUWorker(WorkerBase): peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"] # TODO: don`t need impl this func after empty_cache in # Worker.determine_num_available_blocks() unified` - NPUPlatform.empty_cache() + torch.npu.empty_cache() torch_allocated_bytes = torch_npu.npu.memory_stats( )["allocated_bytes.all.current"] total_allocated_bytes = torch_npu.npu.mem_get_info( @@ -389,7 +396,7 @@ class NPUWorker(WorkerBase): self._warm_up_atb() # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. - NPUPlatform.seed_everything(self.model_config.seed) + set_random_seed(self.model_config.seed) def _warm_up_atb(self): x = torch.rand((2, 4), dtype=torch.float16).npu()