<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? - Enforce recompute scheduler only in PD-disaggregated mode. - Enforce balance scheduling only in PD-mixed mode. - Enforce fused MC2 only on PD-disaggregated D-side (kv_consumer). <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? No <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? By ci <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
865 lines
42 KiB
Python
865 lines
42 KiB
Python
import importlib
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
from vllm.config.compilation import CompilationMode, CUDAGraphMode
|
|
from vllm.platforms import PlatformEnum
|
|
from vllm.v1.attention.selector import AttentionSelectorConfig # type: ignore
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.platform import NPUPlatform
|
|
from vllm_ascend.utils import (
|
|
ASCEND_QUANTIZATION_METHOD,
|
|
COMPRESSED_TENSORS_METHOD,
|
|
AscendDeviceType,
|
|
)
|
|
|
|
|
|
class TestNPUPlatform(TestBase):
|
|
@staticmethod
|
|
def mock_vllm_config():
|
|
mock_vllm_config = MagicMock()
|
|
mock_vllm_config.compilation_config = MagicMock()
|
|
mock_vllm_config.model_config = MagicMock()
|
|
mock_vllm_config.parallel_config = MagicMock()
|
|
mock_vllm_config.cache_config = MagicMock()
|
|
mock_vllm_config.scheduler_config = MagicMock()
|
|
mock_vllm_config.scheduler_config.max_num_seqs = None
|
|
mock_vllm_config.speculative_config = None
|
|
mock_vllm_config.additional_config = {}
|
|
mock_vllm_config.compilation_config.pass_config.enable_sp = False
|
|
mock_vllm_config.compilation_config.cudagraph_mode = None
|
|
return mock_vllm_config
|
|
|
|
@staticmethod
|
|
def mock_vllm_ascend_config():
|
|
mock_ascend_config = MagicMock()
|
|
mock_ascend_config.xlite_graph_config.enabled = False
|
|
mock_ascend_config.xlite_graph_config.full_mode = False
|
|
mock_ascend_config.ascend_compilation_config.enable_npugraph_ex = False
|
|
mock_ascend_config.ascend_fusion_config = None
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_ascend_config.SLO_limits_for_dynamic_batch = -1
|
|
mock_ascend_config.enable_shared_expert_dp = False
|
|
mock_ascend_config.update_compile_ranges_split_points = MagicMock()
|
|
return mock_ascend_config
|
|
|
|
def setUp(self):
|
|
self.platform = NPUPlatform()
|
|
self.platform.supported_quantization[:] = ["ascend", "compressed-tensors"]
|
|
|
|
def test_class_variables(self):
|
|
self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
|
|
self.assertEqual(NPUPlatform.device_name, "npu")
|
|
self.assertEqual(NPUPlatform.device_type, "npu")
|
|
self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
|
|
self.assertEqual(NPUPlatform.ray_device_key, "NPU")
|
|
self.assertEqual(NPUPlatform.device_control_env_var, "ASCEND_RT_VISIBLE_DEVICES")
|
|
self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
|
|
self.assertEqual(NPUPlatform.supported_quantization, [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD])
|
|
|
|
def test_is_sleep_mode_available(self):
|
|
self.assertTrue(self.platform.is_sleep_mode_available())
|
|
|
|
@patch("vllm_ascend.utils.adapt_patch")
|
|
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
|
def test_pre_register_and_update_with_parser(self, mock_quant_config,
|
|
mock_adapt_patch):
|
|
mock_parser = MagicMock()
|
|
mock_action = MagicMock()
|
|
mock_action.choices = ["awq", "gptq"]
|
|
mock_parser._option_string_actions = {"--quantization": mock_action}
|
|
|
|
self.platform.pre_register_and_update(mock_parser)
|
|
|
|
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
|
|
|
self.assertTrue(ASCEND_QUANTIZATION_METHOD in mock_action.choices)
|
|
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend
|
|
|
|
@patch("vllm_ascend.utils.adapt_patch")
|
|
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
|
def test_pre_register_and_update_without_parser(self, mock_quant_config,
|
|
mock_adapt_patch):
|
|
self.platform.pre_register_and_update(None)
|
|
|
|
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
|
|
|
@patch("vllm_ascend.utils.adapt_patch")
|
|
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
|
def test_pre_register_and_update_with_parser_no_quant_action(
|
|
self, mock_quant_config, mock_adapt_patch):
|
|
mock_parser = MagicMock()
|
|
mock_parser._option_string_actions = {}
|
|
|
|
self.platform.pre_register_and_update(mock_parser)
|
|
|
|
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
|
|
|
@patch("vllm_ascend.utils.adapt_patch")
|
|
@patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig")
|
|
def test_pre_register_and_update_with_existing_ascend_quant(
|
|
self, mock_quant_config, mock_adapt_patch):
|
|
mock_parser = MagicMock()
|
|
mock_action = MagicMock()
|
|
mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]
|
|
mock_parser._option_string_actions = {"--quantization": mock_action}
|
|
|
|
self.platform.pre_register_and_update(mock_parser)
|
|
|
|
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
|
|
self.assertEqual(len(mock_action.choices), 2)
|
|
|
|
def test_apply_config_platform_defaults_sets_ascend_default_max(self):
|
|
test_cases = [
|
|
(40, 3, 160),
|
|
(200, 3, 512),
|
|
]
|
|
|
|
for max_num_seqs, num_speculative_tokens, expected_max in test_cases:
|
|
with self.subTest(
|
|
max_num_seqs=max_num_seqs,
|
|
num_speculative_tokens=num_speculative_tokens,
|
|
expected_max=expected_max,
|
|
):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.scheduler_config.max_num_seqs = max_num_seqs
|
|
vllm_config.speculative_config = MagicMock(
|
|
num_speculative_tokens=num_speculative_tokens
|
|
)
|
|
vllm_config.compilation_config.max_cudagraph_capture_size = None
|
|
vllm_config.compilation_config.cudagraph_capture_sizes = None
|
|
|
|
self.platform.apply_config_platform_defaults(vllm_config)
|
|
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.max_cudagraph_capture_size,
|
|
expected_max,
|
|
)
|
|
|
|
def test_apply_config_platform_defaults_respects_explicit_max(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.compilation_config.max_cudagraph_capture_size = 456
|
|
vllm_config.compilation_config.cudagraph_capture_sizes = None
|
|
|
|
self.platform.apply_config_platform_defaults(vllm_config)
|
|
|
|
self.assertEqual(vllm_config.compilation_config.max_cudagraph_capture_size, 456)
|
|
|
|
def test_apply_config_platform_defaults_respects_explicit_sizes(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.compilation_config.max_cudagraph_capture_size = None
|
|
vllm_config.compilation_config.cudagraph_capture_sizes = [1, 2, 4]
|
|
|
|
self.platform.apply_config_platform_defaults(vllm_config)
|
|
|
|
self.assertIsNone(vllm_config.compilation_config.max_cudagraph_capture_size)
|
|
self.assertEqual(vllm_config.compilation_config.cudagraph_capture_sizes, [1, 2, 4])
|
|
|
|
def test_apply_config_platform_defaults_skips_when_scheduler_max_num_seqs_is_missing(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.compilation_config.max_cudagraph_capture_size = None
|
|
vllm_config.compilation_config.cudagraph_capture_sizes = None
|
|
|
|
self.platform.apply_config_platform_defaults(vllm_config)
|
|
|
|
self.assertIsNone(vllm_config.compilation_config.max_cudagraph_capture_size)
|
|
|
|
@patch("vllm_ascend.platform.refresh_block_size")
|
|
@patch("vllm_ascend.platform.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.platform.enable_sp", return_value=False)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
def test_check_and_update_config_preserves_platform_default_max_input(
|
|
self,
|
|
mock_auto_detect,
|
|
mock_init_ascend,
|
|
_mock_enable_sp,
|
|
_mock_device_type,
|
|
_mock_refresh_block_size,
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.scheduler_config.max_num_seqs = 77
|
|
vllm_config.compilation_config.max_cudagraph_capture_size = None
|
|
vllm_config.compilation_config.cudagraph_capture_sizes = None
|
|
vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
|
vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
|
vllm_config.compilation_config.custom_ops = []
|
|
vllm_config.model_config.enforce_eager = False
|
|
vllm_config.model_config.enable_sleep_mode = True
|
|
vllm_config.model_config.is_encoder_decoder = False
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.parallel_config.worker_cls = "manual"
|
|
vllm_config.parallel_config.cp_kv_cache_interleave_size = 1
|
|
vllm_config.cache_config.block_size = 1
|
|
|
|
self.platform.apply_config_platform_defaults(vllm_config)
|
|
|
|
observed_inputs: list[int | None] = []
|
|
vllm_config._set_cudagraph_sizes = MagicMock(
|
|
side_effect=lambda: observed_inputs.append(
|
|
vllm_config.compilation_config.max_cudagraph_capture_size
|
|
)
|
|
)
|
|
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertEqual(observed_inputs, [77])
|
|
|
|
def test_get_device_capability(self):
|
|
self.assertIsNone(self.platform.get_device_capability(device_id=0))
|
|
|
|
@patch("torch.npu.get_device_name")
|
|
def test_get_device_name(self, mock_get_device_name):
|
|
device_id = 0
|
|
device_name = "Ascend910B2"
|
|
mock_get_device_name.return_value = device_name
|
|
self.assertEqual(self.platform.get_device_name(device_id), device_name)
|
|
mock_get_device_name.assert_called_once_with(0)
|
|
|
|
@patch("torch.npu.get_device_properties")
|
|
def test_get_device_uuid(self, mock_get_device_properties):
|
|
device_id = 0
|
|
device_properties = MagicMock()
|
|
device_properties.uuid = "01020304-0000-0000-0000-01020304"
|
|
mock_get_device_properties.return_value = device_properties
|
|
self.assertEqual(self.platform.get_device_uuid(device_id), device_properties.uuid)
|
|
mock_get_device_properties.assert_called_once_with(0)
|
|
|
|
@patch("torch.inference_mode")
|
|
def test_inference_mode(self, mock_inference_mode):
|
|
mock_inference_mode.return_value = None
|
|
self.assertIsNone(self.platform.inference_mode())
|
|
mock_inference_mode.assert_called_once()
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.utils.update_aclgraph_sizes")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("os.environ", {})
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_basic_config_update(
|
|
self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend, mock_auto_detect
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.parallel_config.enable_expert_parallel = False
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
|
|
# Without this reload, when calling self.platform.check_and_update_config,
|
|
# it would execute the original unmocked init_ascend_config method, causing the unit test to fail.
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
mock_init_ascend.assert_called_once_with(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_no_model_config_warning(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config = None
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch.object(platform.NPUPlatform, "_fix_incompatible_config"):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertTrue("Model config is missing" in cm.output[0])
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config.enforce_eager = True
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch.object(platform.NPUPlatform, "_fix_incompatible_config"):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertTrue("Compilation disabled, using eager mode by default" in cm.output[0])
|
|
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.mode,
|
|
CompilationMode.NONE,
|
|
)
|
|
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.cudagraph_mode,
|
|
CUDAGraphMode.NONE,
|
|
)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_unsupported_compilation_level(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config.enforce_eager = False
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
|
|
|
|
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch.object(platform.NPUPlatform, "_fix_incompatible_config"):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertTrue("NPU does not support" in cm.output[0])
|
|
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.mode,
|
|
CompilationMode.NONE,
|
|
)
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.cudagraph_mode,
|
|
CUDAGraphMode.NONE,
|
|
)
|
|
|
|
@pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform")
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version, mock_auto_detect):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config.enforce_eager = False
|
|
vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL
|
|
|
|
with self.assertLogs(logger="vllm", level="INFO") as cm:
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform.check_and_update_config(vllm_config)
|
|
self.assertTrue("cudagraph_mode is not support on NPU. falling back to NONE" in cm.output[0])
|
|
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.mode,
|
|
CompilationMode.NONE,
|
|
)
|
|
self.assertEqual(
|
|
vllm_config.compilation_config.cudagraph_mode,
|
|
CUDAGraphMode.NONE,
|
|
)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_cache_config_block_size(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.cache_config.block_size = None
|
|
vllm_config.cache_config.enable_prefix_caching = True
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertEqual(vllm_config.cache_config.block_size, 128)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_recompute_scheduler_rejects_pd_mixed_no_kv_transfer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = True
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = None
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with pytest.raises(ValueError, match=r"recompute_scheduler_enable.*PD-disaggregated.*PD-mixed"):
|
|
with patch.object(platform.NPUPlatform, "_fix_incompatible_config"):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_recompute_scheduler_rejects_pd_mixed_kv_both(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = True
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_both", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with pytest.raises(ValueError, match=r"recompute_scheduler_enable.*PD-disaggregated.*PD-mixed"):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_balance_scheduler_rejects_pd_disaggregated_kv_producer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_producer", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_BALANCE_SCHEDULING", True, create=True):
|
|
with pytest.raises(ValueError, match=r"VLLM_ASCEND_BALANCE_SCHEDULING.*PD-mixed.*PD-disaggregated"):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_balance_scheduler_rejects_pd_disaggregated_kv_consumer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_consumer", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_BALANCE_SCHEDULING", True, create=True):
|
|
with pytest.raises(ValueError, match=r"VLLM_ASCEND_BALANCE_SCHEDULING.*PD-mixed.*PD-disaggregated"):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_fused_mc2_rejects_pd_mixed_no_kv_transfer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_ascend_config.enable_mc2_hierarchy_comm = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = None
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True):
|
|
with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*PD-mixed"):
|
|
with patch.object(platform.NPUPlatform, "_fix_incompatible_config"):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_fused_mc2_rejects_pd_mixed_kv_both(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_ascend_config.enable_mc2_hierarchy_comm = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_both", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True):
|
|
with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*kv_role='kv_both'"):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_fused_mc2_rejects_pd_disaggregated_kv_producer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_ascend_config.enable_mc2_hierarchy_comm = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_producer", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True):
|
|
with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*kv_role='kv_producer'"):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_fused_mc2_allows_pd_disaggregated_kv_consumer(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
mock_ascend_config.recompute_scheduler_enable = False
|
|
mock_ascend_config.enable_mc2_hierarchy_comm = False
|
|
mock_init_ascend.return_value = mock_ascend_config
|
|
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.kv_transfer_config = MagicMock(kv_role="kv_consumer", engine_id="engine0")
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
vllm_config.scheduler_config = MagicMock()
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform = platform.NPUPlatform()
|
|
|
|
with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True):
|
|
with (
|
|
patch.object(platform.NPUPlatform, "_fix_incompatible_config"),
|
|
patch.object(platform, "check_kv_extra_config"),
|
|
):
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
def test_update_block_size_for_backend_preserves_hybrid_block_size(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config.is_hybrid = True
|
|
vllm_config.cache_config.block_size = 1024
|
|
vllm_config.cache_config.user_specified_block_size = False
|
|
|
|
self.platform.update_block_size_for_backend(vllm_config)
|
|
|
|
self.assertEqual(vllm_config.cache_config.block_size, 1024)
|
|
|
|
def test_update_block_size_for_backend_preserves_user_block_size(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.model_config.is_hybrid = False
|
|
vllm_config.cache_config.block_size = 512
|
|
vllm_config.cache_config.user_specified_block_size = True
|
|
|
|
self.platform.update_block_size_for_backend(vllm_config)
|
|
|
|
self.assertEqual(vllm_config.cache_config.block_size, 512)
|
|
|
|
def test_validate_layer_sharding_config_accepts_single_node(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
|
vllm_config.kv_transfer_config = None
|
|
|
|
self.platform._validate_layer_sharding_config(vllm_config)
|
|
|
|
def test_validate_layer_sharding_config_accepts_kv_producer(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
|
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_producer")
|
|
|
|
self.platform._validate_layer_sharding_config(vllm_config)
|
|
|
|
def test_validate_layer_sharding_config_rejects_non_kv_producer(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
|
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=False, kv_role="kv_consumer")
|
|
|
|
with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"):
|
|
self.platform._validate_layer_sharding_config(vllm_config)
|
|
|
|
def test_validate_layer_sharding_config_rejects_kv_both(self):
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
|
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_both")
|
|
|
|
with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"):
|
|
self.platform._validate_layer_sharding_config(vllm_config)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_v1_worker_class_selection(
|
|
self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect
|
|
):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.parallel_config.worker_cls = "auto"
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
vllm_config.scheduler_config = MagicMock()
|
|
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
self.platform.check_and_update_config(vllm_config)
|
|
|
|
self.assertEqual(
|
|
vllm_config.parallel_config.worker_cls,
|
|
"vllm_ascend.worker.worker.NPUWorker",
|
|
)
|
|
|
|
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
|
|
test_ascend_config.xlite_graph_config.enabled = True
|
|
mock_init_ascend.return_value = test_ascend_config
|
|
vllm_config.parallel_config.worker_cls = "auto"
|
|
self.platform.check_and_update_config(vllm_config)
|
|
self.assertEqual(
|
|
vllm_config.parallel_config.worker_cls,
|
|
"vllm_ascend.xlite.xlite_worker.XliteWorker",
|
|
)
|
|
|
|
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
|
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
|
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P)
|
|
@patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config")
|
|
def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend, mock_auto_detect):
|
|
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config()
|
|
vllm_config = TestNPUPlatform.mock_vllm_config()
|
|
vllm_config.compilation_config.custom_ops = []
|
|
vllm_config.parallel_config.decode_context_parallel_size = 1
|
|
vllm_config.parallel_config.prefill_context_parallel_size = 1
|
|
vllm_config.parallel_config.tensor_parallel_size = 1
|
|
mock_init_recompute.return_value = MagicMock()
|
|
|
|
vllm_config.scheduler_config = MagicMock()
|
|
from vllm_ascend import platform
|
|
|
|
importlib.reload(platform)
|
|
|
|
self.platform.check_and_update_config(vllm_config)
|
|
self.assertEqual(vllm_config.compilation_config.custom_ops, [])
|
|
|
|
def test_get_attn_backend_cls_use_v1_and_mla(self):
|
|
attn_selector_config = AttentionSelectorConfig(
|
|
dtype=torch.float16,
|
|
head_size=0,
|
|
kv_cache_dtype=None,
|
|
block_size=128,
|
|
use_mla=True,
|
|
use_sparse=False,
|
|
)
|
|
result = self.platform.get_attn_backend_cls("ascend", attn_selector_config)
|
|
self.assertEqual(result, "vllm_ascend.attention.mla_v1.AscendMLABackend")
|
|
|
|
def test_get_attn_backend_cls_use_v1_only(self):
|
|
attn_selector_config = AttentionSelectorConfig(
|
|
dtype=torch.float16,
|
|
head_size=0,
|
|
kv_cache_dtype=None,
|
|
block_size=128,
|
|
use_mla=False,
|
|
use_sparse=False,
|
|
)
|
|
result = self.platform.get_attn_backend_cls("ascend", attn_selector_config)
|
|
self.assertEqual(result, "vllm_ascend.attention.attention_v1.AscendAttentionBackend")
|
|
|
|
def test_get_punica_wrapper(self):
|
|
result = self.platform.get_punica_wrapper()
|
|
|
|
self.assertEqual(result, "vllm_ascend.lora.punica_npu.PunicaWrapperNPU")
|
|
|
|
@patch("torch.npu.reset_peak_memory_stats")
|
|
@patch("torch.npu.max_memory_allocated")
|
|
def test_get_current_memory_usage_with_specific_device(self, mock_max_memory, mock_reset_stats):
|
|
max_memory_allocated_result = 1024.0
|
|
mock_max_memory.return_value = max_memory_allocated_result
|
|
test_device = torch.device("npu:0")
|
|
result = self.platform.get_current_memory_usage(device=test_device)
|
|
|
|
mock_reset_stats.assert_called_once_with(test_device)
|
|
mock_max_memory.assert_called_once_with(test_device)
|
|
self.assertEqual(result, max_memory_allocated_result)
|
|
|
|
@patch("torch.npu.reset_peak_memory_stats")
|
|
@patch("torch.npu.max_memory_allocated")
|
|
def test_get_current_memory_usage_with_default_device(self, mock_max_memory, mock_reset_stats):
|
|
max_memory_allocated_result = 1024.0
|
|
mock_max_memory.return_value = max_memory_allocated_result
|
|
|
|
result = self.platform.get_current_memory_usage()
|
|
|
|
mock_reset_stats.assert_called_once_with(None)
|
|
mock_max_memory.assert_called_once_with(None)
|
|
self.assertEqual(result, max_memory_allocated_result)
|
|
|
|
@patch("torch.npu.reset_peak_memory_stats", side_effect=RuntimeError("Device error"))
|
|
@patch("torch.npu.max_memory_allocated")
|
|
def test_get_current_memory_usage_when_reset_stats_fails(self, mock_max_memory, mock_reset_stats):
|
|
with self.assertRaises(RuntimeError):
|
|
self.platform.get_current_memory_usage()
|
|
mock_reset_stats.assert_called_once()
|
|
mock_max_memory.assert_not_called()
|
|
|
|
@patch("torch.npu.reset_peak_memory_stats")
|
|
@patch(
|
|
"torch.npu.max_memory_allocated",
|
|
side_effect=RuntimeError("Memory query failed"),
|
|
)
|
|
def test_get_current_memory_usage_when_query_fails(self, mock_max_memory, mock_reset_stats):
|
|
with self.assertRaises(RuntimeError):
|
|
self.platform.get_current_memory_usage()
|
|
mock_reset_stats.assert_called_once()
|
|
mock_max_memory.assert_called_once()
|
|
|
|
def test_get_device_communicator_cls_returns_correct_value(self):
|
|
self.assertEqual(
|
|
self.platform.get_device_communicator_cls(),
|
|
"vllm_ascend.distributed.device_communicators.npu_communicator.NPUCommunicator",
|
|
)
|
|
|
|
def test_is_pin_memory_available_returns_true(self):
|
|
self.assertTrue(self.platform.is_pin_memory_available())
|
|
|
|
def test_get_static_graph_wrapper_cls_returns_correct_value(self):
|
|
self.assertEqual(
|
|
self.platform.get_static_graph_wrapper_cls(),
|
|
"vllm_ascend.compilation.acl_graph.ACLGraphWrapper",
|
|
)
|