[Refactor] Cleanup platform (#5566)
### What this PR does / why we need it?
1. add `COMPILATION_PASS_KEY` constant
2. clean up useless platform interface `empty_cache`, `synchronize`,
`mem_get_info`, `clear_npu_memory`
3. rename `CUSTOM_OP_REGISTERED` to `_CUSTOM_OP_REGISTERED`
4. remove uesless env `VLLM_ENABLE_CUDAGRAPH_GC`
NPUPlatform is the interface called by vLLM. Do not call it inner
vllm-ascend.
### Does this PR introduce _any_ user-facing change?
This PR is just a cleanup. All CI should pass.
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
7157596103
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -879,7 +879,6 @@ class TestAscendMLAImpl(TestBase):
|
||||
B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim # total: [4, 4, 8]
|
||||
test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
|
||||
for test_case in test_cases:
|
||||
print(test_case)
|
||||
self.impl.dcp_size = test_case[0]
|
||||
self.impl.pcp_size = test_case[1]
|
||||
mock_dcp.world_size = test_case[0]
|
||||
|
||||
@@ -128,10 +128,17 @@ class TestCaMem(PytestBase):
|
||||
2000: data2,
|
||||
}
|
||||
|
||||
# mock is_pin_memory_available, return False as some machine only has cpu
|
||||
with patch(
|
||||
"vllm_ascend.device_allocator.camem.NPUPlatform.is_pin_memory_available",
|
||||
return_value=False):
|
||||
# Mock torch.empty to force pin_memory=False
|
||||
original_torch_empty = torch.empty
|
||||
|
||||
def mock_torch_empty(*args, **kwargs):
|
||||
# If pin_memory was explicitly set to True, change it to False
|
||||
if 'pin_memory' in kwargs and kwargs['pin_memory'] is True:
|
||||
kwargs['pin_memory'] = False
|
||||
return original_torch_empty(*args, **kwargs)
|
||||
|
||||
with patch("vllm_ascend.device_allocator.camem.torch.empty",
|
||||
side_effect=mock_torch_empty):
|
||||
allocator.sleep(offload_tags="tag1")
|
||||
|
||||
# only offload tag1, other tag2 call unmap_and_release
|
||||
|
||||
@@ -120,115 +120,6 @@ class TestNPUPlatform(TestBase):
|
||||
self.assertIsNone(self.platform.inference_mode())
|
||||
mock_inference_mode.assert_called_once()
|
||||
|
||||
@patch("torch.npu.set_device")
|
||||
def test_set_device_normal(self, mock_set_device):
|
||||
device = torch.device("npu:0")
|
||||
self.platform.set_device(device)
|
||||
mock_set_device.assert_called_once_with(device)
|
||||
|
||||
@patch("torch.npu.set_device",
|
||||
side_effect=RuntimeError("Device not available"))
|
||||
def test_set_device_failure(self, mock_set_device):
|
||||
device = torch.device("npu:0")
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.set_device(device)
|
||||
mock_set_device.assert_called_once_with(device)
|
||||
|
||||
@patch("torch.npu.empty_cache")
|
||||
def test_empty_cache_normal(self, mock_empty_cache):
|
||||
self.platform.empty_cache()
|
||||
mock_empty_cache.assert_called_once()
|
||||
|
||||
@patch("torch.npu.empty_cache",
|
||||
side_effect=RuntimeError("Cache clearing failed"))
|
||||
def test_empty_cache_failure(self, mock_empty_cache):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.empty_cache()
|
||||
mock_empty_cache.assert_called_once()
|
||||
|
||||
@patch("torch.npu.synchronize")
|
||||
def test_synchronize_normal(self, mock_synchronize):
|
||||
self.platform.synchronize()
|
||||
mock_synchronize.assert_called_once()
|
||||
|
||||
@patch("torch.npu.synchronize",
|
||||
side_effect=RuntimeError("Synchronization failed"))
|
||||
def test_synchronize_failure(self, mock_synchronize):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.synchronize()
|
||||
mock_synchronize.assert_called_once()
|
||||
|
||||
@patch("torch.npu.mem_get_info")
|
||||
def test_mem_get_info_normal(self, mock_mem_get_info):
|
||||
free_memory_size = 1024
|
||||
total_memory_size = 2048
|
||||
memory_info = (free_memory_size, total_memory_size)
|
||||
mock_mem_get_info.return_value = memory_info
|
||||
result = self.platform.mem_get_info()
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result, memory_info)
|
||||
mock_mem_get_info.assert_called_once()
|
||||
|
||||
@patch("torch.npu.mem_get_info",
|
||||
side_effect=RuntimeError("NPU not available"))
|
||||
def test_mem_get_info_failure(self, mock_mem_get_info):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.mem_get_info()
|
||||
mock_mem_get_info.assert_called_once()
|
||||
|
||||
@patch("gc.collect")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache,
|
||||
mock_gc_collect):
|
||||
self.platform.clear_npu_memory()
|
||||
|
||||
mock_gc_collect.assert_called_once()
|
||||
mock_empty_cache.assert_called_once()
|
||||
mock_reset_stats.assert_called_once()
|
||||
|
||||
@patch("gc.collect", side_effect=Exception("GC failed"))
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
|
||||
mock_empty_cache,
|
||||
mock_gc_collect):
|
||||
with self.assertRaises(Exception):
|
||||
self.platform.clear_npu_memory()
|
||||
|
||||
mock_gc_collect.assert_called_once()
|
||||
mock_empty_cache.assert_not_called()
|
||||
mock_reset_stats.assert_not_called()
|
||||
|
||||
@patch("gc.collect")
|
||||
@patch("torch.npu.empty_cache",
|
||||
side_effect=RuntimeError("Cache clear failed"))
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
|
||||
mock_empty_cache,
|
||||
mock_gc_collect):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.clear_npu_memory()
|
||||
|
||||
mock_gc_collect.assert_called_once()
|
||||
mock_empty_cache.assert_called_once()
|
||||
mock_reset_stats.assert_not_called()
|
||||
|
||||
@patch("gc.collect")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch.npu.reset_peak_memory_stats",
|
||||
side_effect=RuntimeError("Reset failed"))
|
||||
def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
|
||||
mock_empty_cache,
|
||||
mock_gc_collect):
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.platform.clear_npu_memory()
|
||||
|
||||
mock_gc_collect.assert_called_once()
|
||||
mock_empty_cache.assert_called_once()
|
||||
mock_reset_stats.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
||||
@patch('vllm_ascend.utils.get_ascend_device_type',
|
||||
|
||||
@@ -238,15 +238,18 @@ class TestNPUWorker(TestBase):
|
||||
@patch(
|
||||
"vllm_ascend.worker.worker.NPUWorker._init_worker_distributed_environment"
|
||||
)
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform")
|
||||
@patch("vllm_ascend.worker.worker.init_device_properties_triton")
|
||||
def test_init_device(self, mock_init_triton, mock_platform,
|
||||
@patch("torch.npu.set_device")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch.npu.mem_get_info")
|
||||
def test_init_device(self, mock_mem_get_info, mock_set_device,
|
||||
mock_empty_cache, mock_init_triton,
|
||||
mock_init_dist_env):
|
||||
"""Test _init_device method"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
# Setup mock
|
||||
mock_platform.mem_get_info.return_value = (1000, 2000)
|
||||
mock_mem_get_info.return_value = (1000, 2000)
|
||||
|
||||
# Create worker mock
|
||||
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
|
||||
@@ -256,21 +259,13 @@ class TestNPUWorker(TestBase):
|
||||
worker.parallel_config = MagicMock()
|
||||
worker.parallel_config.local_world_size = 0
|
||||
worker.parallel_config.data_parallel_size = 1
|
||||
|
||||
worker.model_config.seed = 42
|
||||
|
||||
# Test _init_device
|
||||
result = worker._init_device()
|
||||
|
||||
# Verify NPUPlatform.set_device is called
|
||||
mock_platform.set_device.assert_called_once()
|
||||
# Verify the parameter passed to set_device is a torch.device object
|
||||
call_args = mock_platform.set_device.call_args[0][0]
|
||||
self.assertEqual(str(call_args), "npu:1")
|
||||
|
||||
mock_platform.empty_cache.assert_called_once()
|
||||
mock_platform.seed_everything.assert_called_once_with(42)
|
||||
mock_platform.mem_get_info.assert_called_once(
|
||||
mock_mem_get_info.assert_called_once(
|
||||
) # Called once in _init_device method
|
||||
mock_init_dist_env.assert_called_once(
|
||||
) # Verify distributed initialization is called
|
||||
@@ -548,9 +543,8 @@ class TestNPUWorker(TestBase):
|
||||
# Verify returns None (empty string is considered false)
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch_npu.npu.memory_stats")
|
||||
@patch("torch_npu.npu.mem_get_info")
|
||||
@patch("vllm_ascend.worker.worker.logger")
|
||||
@@ -559,15 +553,14 @@ class TestNPUWorker(TestBase):
|
||||
mock_logger,
|
||||
mock_torch_mem_get_info,
|
||||
mock_torch_memory_stats,
|
||||
mock_platform_mem_get_info,
|
||||
mock_platform_empty_cache,
|
||||
mock_platform_clear_npu_memory,
|
||||
mock_torch_empty_cache,
|
||||
mock_torch_reset_peak_memory_stats,
|
||||
):
|
||||
"""Test determine_available_memory normal case (no non-torch memory allocation)"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
# Setup mock - test case without non-torch memory allocation
|
||||
mock_platform_mem_get_info.side_effect = [
|
||||
mock_torch_mem_get_info.side_effect = [
|
||||
(8000, 10000), # 1st call: before profile execution
|
||||
(7000, 10000), # 2nd call: after profile execution
|
||||
]
|
||||
@@ -606,10 +599,8 @@ class TestNPUWorker(TestBase):
|
||||
result = worker.determine_available_memory()
|
||||
|
||||
# Verify call count and order
|
||||
mock_platform_clear_npu_memory.assert_called_once()
|
||||
self.assertEqual(mock_platform_mem_get_info.call_count, 2)
|
||||
self.assertEqual(mock_torch_mem_get_info.call_count, 4)
|
||||
worker.model_runner.profile_run.assert_called_once()
|
||||
mock_platform_empty_cache.assert_called_once()
|
||||
|
||||
# Verify calculation result with race condition simulation
|
||||
# Calculation logic:
|
||||
@@ -629,24 +620,22 @@ class TestNPUWorker(TestBase):
|
||||
# Verify log output
|
||||
mock_logger.info.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch_npu.npu.memory_stats")
|
||||
@patch("torch_npu.npu.mem_get_info")
|
||||
def test_determine_available_memory_with_non_torch_allocations(
|
||||
self,
|
||||
mock_torch_mem_get_info,
|
||||
mock_torch_memory_stats,
|
||||
mock_platform_mem_get_info,
|
||||
mock_platform_empty_cache,
|
||||
mock_platform_clear_npu_memory,
|
||||
mock_torch_empty_cache,
|
||||
mock_torch_reset_peak_memory_stats,
|
||||
):
|
||||
"""Test determine_available_memory with significant non-torch memory allocation"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
# Setup mock - test case with large non-torch memory allocation
|
||||
mock_platform_mem_get_info.side_effect = [
|
||||
mock_torch_mem_get_info.side_effect = [
|
||||
(8000, 10000), # 1st call
|
||||
(7000, 10000), # 2nd call
|
||||
]
|
||||
@@ -695,15 +684,17 @@ class TestNPUWorker(TestBase):
|
||||
expected_result = max(0, int(10000 * 0.9 - 5500))
|
||||
self.assertEqual(result, expected_result)
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info")
|
||||
@patch("torch.npu.mem_get_info")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
@patch("torch.npu.empty_cache")
|
||||
def test_determine_available_memory_memory_profiling_error(
|
||||
self, mock_platform_mem_get_info, mock_platform_clear_npu_memory):
|
||||
self, mock_torch_empty_cache, mock_torch_reset_peak_memory_stats,
|
||||
mock_torch_mem_get_info):
|
||||
"""Test determine_available_memory throws exception on memory profiling error"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
# Setup mock: initial memory less than current free memory (error case)
|
||||
mock_platform_mem_get_info.side_effect = [
|
||||
mock_torch_mem_get_info.side_effect = [
|
||||
(8000, 10000), # 1st call
|
||||
(9000, 10000), # 2nd call: free memory increased instead
|
||||
]
|
||||
@@ -722,24 +713,22 @@ class TestNPUWorker(TestBase):
|
||||
|
||||
self.assertIn("Error in memory profiling", str(cm.exception))
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.clear_npu_memory")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.empty_cache")
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.mem_get_info")
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
@patch("torch.npu.empty_cache")
|
||||
@patch("torch_npu.npu.memory_stats")
|
||||
@patch("torch_npu.npu.mem_get_info")
|
||||
def test_determine_available_memory_negative_result(
|
||||
self,
|
||||
mock_torch_mem_get_info,
|
||||
mock_torch_memory_stats,
|
||||
mock_platform_mem_get_info,
|
||||
mock_platform_empty_cache,
|
||||
mock_platform_clear_npu_memory,
|
||||
mock_torch_empty_cache,
|
||||
mock_torch_reset_peak_memory_stats,
|
||||
):
|
||||
"""Test determine_available_memory returns 0 when result is negative"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
# Setup mock: high peak memory causes negative available memory
|
||||
mock_platform_mem_get_info.side_effect = [
|
||||
mock_torch_mem_get_info.side_effect = [
|
||||
(8000, 10000), # 1st call
|
||||
(3000, 10000), # 2nd call
|
||||
]
|
||||
@@ -989,12 +978,10 @@ class TestNPUWorker(TestBase):
|
||||
|
||||
self.assertIn("Sleep mode can only be", str(cm.exception))
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything")
|
||||
@patch("vllm_ascend.worker.worker.logger")
|
||||
@patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb")
|
||||
def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
|
||||
mock_logger,
|
||||
mock_seed_everything):
|
||||
mock_logger):
|
||||
"""Test compile_or_warm_up_model method - eager mode"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
@@ -1032,17 +1019,13 @@ class TestNPUWorker(TestBase):
|
||||
# Verify log output
|
||||
self.assertEqual(mock_logger.info.call_count, 4)
|
||||
|
||||
# Verify seed setting
|
||||
mock_seed_everything.assert_called_once_with(12345)
|
||||
|
||||
# Verify atb warm up
|
||||
mock_warm_up_atb.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.worker.worker.NPUPlatform.seed_everything")
|
||||
@patch("vllm_ascend.worker.worker.logger")
|
||||
@patch("vllm_ascend.worker.worker.NPUWorker._warm_up_atb")
|
||||
def test_compile_or_warm_up_model_with_graph_capture(
|
||||
self, mock_warm_up_atb, mock_logger, mock_seed_everything):
|
||||
self, mock_warm_up_atb, mock_logger):
|
||||
"""Test compile_or_warm_up_model method - with graph capture enabled"""
|
||||
from vllm_ascend.worker.worker import NPUWorker
|
||||
|
||||
@@ -1072,9 +1055,6 @@ class TestNPUWorker(TestBase):
|
||||
# Should call capture_model in non-eager mode
|
||||
worker.model_runner.capture_model.assert_called_once()
|
||||
|
||||
# Verify seed setting
|
||||
mock_seed_everything.assert_called_once_with(67890)
|
||||
|
||||
# Verify atb warm up
|
||||
mock_warm_up_atb.assert_called_once()
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from torch.fx import GraphModule
|
||||
from vllm.compilation.compiler_interface import CompilerInterface
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import COMPILATION_PASS_KEY
|
||||
|
||||
|
||||
def compile_fx(graph: GraphModule, example_inputs: list,
|
||||
@@ -51,7 +52,7 @@ def fusion_pass_compile(
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
def compile_inner(graph, example_inputs):
|
||||
current_pass_manager = compiler_config["graph_fusion_manager"]
|
||||
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
|
||||
graph = current_pass_manager(graph, runtime_shape)
|
||||
return graph
|
||||
|
||||
|
||||
@@ -25,8 +25,6 @@ import torch
|
||||
from acl.rt import memcpy # type: ignore # noqa: F401
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
def find_loaded_library(lib_name) -> Optional[str]:
|
||||
"""
|
||||
@@ -196,11 +194,10 @@ class CaMemAllocator:
|
||||
handle = data.handle
|
||||
if data.tag in offload_tags:
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device='cpu',
|
||||
pin_memory=NPUPlatform.is_pin_memory_available())
|
||||
cpu_backup_tensor = torch.empty(size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device='cpu',
|
||||
pin_memory=True)
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
ACL_MEMCPY_DEVICE_TO_HOST = 2
|
||||
dest_max = cpu_ptr + size_in_bytes * 2
|
||||
|
||||
@@ -15,9 +15,8 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import torch
|
||||
@@ -26,18 +25,16 @@ from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
# todo: please remove it when solve cuda hard code in vllm
|
||||
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "1"
|
||||
# todo: please remove it when support controls garbage collection during CUDA graph capture.
|
||||
os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.utils import refresh_block_size
|
||||
|
||||
# isort: off
|
||||
from vllm_ascend.utils import (
|
||||
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
||||
enable_sp, get_ascend_device_type, is_vl_model, update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes, update_default_aclgraph_sizes,
|
||||
check_kv_extra_config)
|
||||
ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD,
|
||||
COMPILATION_PASS_KEY, AscendDeviceType, enable_sp, get_ascend_device_type,
|
||||
is_vl_model, update_aclgraph_sizes, update_cudagraph_capture_sizes,
|
||||
update_default_aclgraph_sizes, check_kv_extra_config)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
@@ -47,7 +44,7 @@ else:
|
||||
VllmConfig = None
|
||||
FlexibleArgumentParser = None
|
||||
|
||||
CUSTOM_OP_REGISTERED = False
|
||||
_CUSTOM_OP_REGISTERED = False
|
||||
|
||||
|
||||
class NPUPlatform(Platform):
|
||||
@@ -74,7 +71,7 @@ class NPUPlatform(Platform):
|
||||
It is a parameter of inductor_config used to register custom passes.
|
||||
Currently, we only use Inductor's 'pattern matcher' functionality, so we define our own pass_key.
|
||||
"""
|
||||
return "graph_fusion_manager"
|
||||
return COMPILATION_PASS_KEY
|
||||
|
||||
@classmethod
|
||||
def get_pass_manager_cls(cls) -> str:
|
||||
@@ -131,24 +128,6 @@ class NPUPlatform(Platform):
|
||||
def set_device(cls, device: torch.device):
|
||||
torch.npu.set_device(device)
|
||||
|
||||
@classmethod
|
||||
def empty_cache(cls):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def synchronize(cls):
|
||||
torch.npu.synchronize()
|
||||
|
||||
@classmethod
|
||||
def mem_get_info(cls) -> Tuple[int, int]:
|
||||
return torch.npu.mem_get_info()
|
||||
|
||||
@classmethod
|
||||
def clear_npu_memory(cls):
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
# initialize ascend config from vllm additional_config
|
||||
@@ -351,8 +330,8 @@ class NPUPlatform(Platform):
|
||||
# from vllm_ascend.utils import enable_custom_op
|
||||
# enable_custom_op()
|
||||
# set custom ops path
|
||||
global CUSTOM_OP_REGISTERED
|
||||
if CUSTOM_OP_REGISTERED:
|
||||
global _CUSTOM_OP_REGISTERED
|
||||
if _CUSTOM_OP_REGISTERED:
|
||||
return
|
||||
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
CUSTOM_OPP_PATH = os.path.join(CUR_DIR, "_cann_ops_custom", "vendors",
|
||||
@@ -365,7 +344,7 @@ class NPUPlatform(Platform):
|
||||
"ASCEND_CUSTOM_OPP_PATH"] = f"{CUSTOM_OPP_PATH}:{current_cust_opp_path}"
|
||||
else:
|
||||
os.environ["ASCEND_CUSTOM_OPP_PATH"] = CUSTOM_OPP_PATH
|
||||
CUSTOM_OP_REGISTERED = True
|
||||
_CUSTOM_OP_REGISTERED = True
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, attn_selector_config):
|
||||
|
||||
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
VllmConfig = None
|
||||
|
||||
COMPILATION_PASS_KEY = "graph_fusion_manager"
|
||||
ASCEND_QUANTIZATION_METHOD = "ascend"
|
||||
COMPRESSED_TENSORS_METHOD = "compressed-tensors"
|
||||
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
import gc
|
||||
from types import NoneType
|
||||
from typing import Optional
|
||||
|
||||
@@ -55,15 +56,19 @@ from vllm_ascend.cpu_binding import bind_cpus
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
|
||||
enable_sp, get_ascend_device_type,
|
||||
register_ascend_customop)
|
||||
register_ascend_customop, vllm_version_is)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402
|
||||
from torch._dynamo.variables import TorchInGraphFunctionVariable # noqa: E402
|
||||
|
||||
if vllm_version_is("0.13.0"):
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
else:
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
torch_non_c_binding_in_graph_functions_npu = dict.fromkeys(
|
||||
["torch.npu.current_stream"],
|
||||
TorchInGraphFunctionVariable,
|
||||
@@ -147,7 +152,7 @@ class NPUWorker(WorkerBase):
|
||||
self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
|
||||
|
||||
def sleep(self, level: int = 1) -> None:
|
||||
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
|
||||
free_bytes_before_sleep = torch.npu.mem_get_info()[0]
|
||||
# Save the buffers before level 2 sleep
|
||||
if level == 2:
|
||||
model = self.model_runner.model
|
||||
@@ -157,7 +162,7 @@ class NPUWorker(WorkerBase):
|
||||
}
|
||||
allocator = CaMemAllocator.get_instance()
|
||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||
free_bytes_after_sleep, total = NPUPlatform.mem_get_info()
|
||||
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
used_bytes = total - free_bytes_after_sleep
|
||||
assert freed_bytes >= 0, "Memory usage increased after sleeping."
|
||||
@@ -210,8 +215,8 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
def _init_device(self):
|
||||
device = torch.device(f"npu:{self.local_rank}")
|
||||
NPUPlatform.set_device(device)
|
||||
NPUPlatform.empty_cache()
|
||||
torch.npu.set_device(device)
|
||||
torch.npu.empty_cache()
|
||||
|
||||
if (self.parallel_config.data_parallel_size > 1
|
||||
and self.parallel_config.data_parallel_size_local > 0
|
||||
@@ -226,11 +231,11 @@ class NPUWorker(WorkerBase):
|
||||
f"be less than or equal to the number of visible devices "
|
||||
f"({visible_device_count}).")
|
||||
|
||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||
self.init_npu_memory = torch.npu.mem_get_info()[0]
|
||||
# Initialize the distributed environment.
|
||||
self._init_worker_distributed_environment()
|
||||
# Set random seed.
|
||||
NPUPlatform.seed_everything(self.model_config.seed)
|
||||
set_random_seed(self.model_config.seed)
|
||||
# Initialize device properties used by triton kernels.
|
||||
init_device_properties_triton()
|
||||
return device
|
||||
@@ -258,16 +263,18 @@ class NPUWorker(WorkerBase):
|
||||
def determine_available_memory(self) -> int:
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
NPUPlatform.clear_npu_memory()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
_, total_npu_memory = NPUPlatform.mem_get_info()
|
||||
_, total_npu_memory = torch.npu.mem_get_info()
|
||||
self.model_runner.profile_run()
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
free_npu_memory, _ = NPUPlatform.mem_get_info()
|
||||
free_npu_memory, _ = torch.npu.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
assert self.init_npu_memory > free_npu_memory, (
|
||||
@@ -280,7 +287,7 @@ class NPUWorker(WorkerBase):
|
||||
peak_memory = torch_npu.npu.memory_stats()["allocated_bytes.all.peak"]
|
||||
# TODO: don`t need impl this func after empty_cache in
|
||||
# Worker.determine_num_available_blocks() unified`
|
||||
NPUPlatform.empty_cache()
|
||||
torch.npu.empty_cache()
|
||||
torch_allocated_bytes = torch_npu.npu.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
total_allocated_bytes = torch_npu.npu.mem_get_info(
|
||||
@@ -389,7 +396,7 @@ class NPUWorker(WorkerBase):
|
||||
self._warm_up_atb()
|
||||
# Reset the seed to ensure that the random state is not affected by
|
||||
# the model initialization and profiling.
|
||||
NPUPlatform.seed_everything(self.model_config.seed)
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def _warm_up_atb(self):
|
||||
x = torch.rand((2, 4), dtype=torch.float16).npu()
|
||||
|
||||
Reference in New Issue
Block a user