### What this PR does / why we need it?
Remove the ETP/EP implementation maintained on the main branch. We drop it
because there are currently no relevant scenarios for ETP, and we may later
advocate implementing expert tensor parallelism in vLLM itself to support
scenarios where experts need to be sliced.
This is part of the #1422 backport.
Fixes https://github.com/vllm-project/vllm-ascend/issues/1396
Fixes https://github.com/vllm-project/vllm-ascend/issues/1154
### Does this PR introduce _any_ user-facing change?
vllm-ascend will no longer maintain its own ETP/EP implementation; the TP/EP
implementation in vLLM is used instead.
### How was this patch tested?
CI passed with the newly added and existing tests.
- vLLM version: v0.9.2
- vLLM main: fe8a2c544a
Signed-off-by: MengqingCao <cmq0113@163.com>
589 lines · 25 KiB · Python
import importlib
|
|
import unittest
|
|
from datetime import timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
from torch.distributed import ProcessGroup
|
|
from torch.distributed.distributed_c10d import PrefixStore
|
|
from vllm.config import CompilationLevel
|
|
from vllm.platforms import PlatformEnum
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.platform import NPUPlatform
|
|
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
|
|
|
|
|
|
class TestNPUPlatform(TestBase):
    """Unit tests for ``vllm_ascend.platform.NPUPlatform``.

    NPU runtime calls (``torch.npu.*``), torch-npu process groups and the
    Ascend config helpers are replaced with mocks, so these tests exercise
    only the platform's own dispatch and config-rewriting logic.
    """

    def setUp(self):
        self.platform = NPUPlatform()

        # Fully mocked vLLM config; individual tests override the fields
        # they care about.
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.compilation_config = MagicMock()
        self.mock_vllm_config.model_config = MagicMock()
        self.mock_vllm_config.parallel_config = MagicMock()
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.scheduler_config = MagicMock()
        self.mock_vllm_config.speculative_config = None

        # Ascend config handed out by the mocked init_ascend_config; torchair
        # graph mode and the Ascend scheduler are disabled by default.
        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.torchair_graph_config.enabled = False
        self.mock_ascend_config.ascend_scheduler_config.enabled = False

    def _reload_platform(self):
        """Reload ``vllm_ascend.platform`` so it picks up active patches.

        The platform module binds helpers (e.g. ``init_ascend_config``,
        ``ProcessGroup``) when it is imported, so a ``patch`` on the defining
        module only becomes visible to ``NPUPlatform`` after a reload.
        The reloaded namespace would otherwise keep those mocks after the
        patches exit and leak them into later tests. A cleanup therefore
        reloads the module again. Cleanups run after the test method, and
        therefore after its ``@patch`` decorators, has finished, so that
        second reload re-binds the real objects.
        """
        from vllm_ascend import platform

        importlib.reload(platform)
        self.addCleanup(importlib.reload, platform)

    def test_class_variables(self):
        self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
        self.assertEqual(NPUPlatform.device_name, "npu")
        self.assertEqual(NPUPlatform.device_type, "npu")
        self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
        self.assertEqual(NPUPlatform.ray_device_key, "NPU")
        self.assertEqual(NPUPlatform.device_control_env_var,
                         "ASCEND_RT_VISIBLE_DEVICES")
        self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
        self.assertEqual(NPUPlatform.supported_quantization,
                         [ASCEND_QUATIZATION_METHOD])

    def test_is_sleep_mode_available(self):
        self.assertTrue(self.platform.is_sleep_mode_available())

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser(self, mock_quant_config,
                                                 mock_adapt_patch):
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", "gptq"]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

        self.assertIn(ASCEND_QUATIZATION_METHOD, mock_action.choices)
        self.assertEqual(len(mock_action.choices), 3)  # original 2 + ascend

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_without_parser(self, mock_quant_config,
                                                    mock_adapt_patch):
        self.platform.pre_register_and_update(None)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser_no_quant_action(
            self, mock_quant_config, mock_adapt_patch):
        mock_parser = MagicMock()
        mock_parser._option_string_actions = {}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_existing_ascend_quant(
            self, mock_quant_config, mock_adapt_patch):
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", ASCEND_QUATIZATION_METHOD]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)
        # The ascend method is already present, so it must not be duplicated.
        self.assertEqual(len(mock_action.choices), 2)

    def test_get_device_capability(self):
        self.assertIsNone(self.platform.get_device_capability(device_id=0))

    @patch("torch.npu.get_device_name")
    def test_get_device_name(self, mock_get_device_name):
        device_id = 0
        device_name = "Ascend910B2"
        mock_get_device_name.return_value = device_name
        self.assertEqual(self.platform.get_device_name(device_id), device_name)
        mock_get_device_name.assert_called_once_with(device_id)

    def test_is_async_output_supported(self):
        for enforce_eager in (None, True, False):
            with self.subTest(enforce_eager=enforce_eager):
                self.assertTrue(
                    self.platform.is_async_output_supported(
                        enforce_eager=enforce_eager))

    @patch("torch.inference_mode")
    def test_inference_mode(self, mock_inference_mode):
        mock_inference_mode.return_value = None
        self.assertIsNone(self.platform.inference_mode())
        mock_inference_mode.assert_called_once()

    @patch("torch.npu.set_device")
    def test_set_device_normal(self, mock_set_device):
        device = torch.device("npu:0")
        self.platform.set_device(device)
        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.set_device",
           side_effect=RuntimeError("Device not available"))
    def test_set_device_failure(self, mock_set_device):
        device = torch.device("npu:0")
        with self.assertRaises(RuntimeError):
            self.platform.set_device(device)
        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.empty_cache")
    def test_empty_cache_normal(self, mock_empty_cache):
        self.platform.empty_cache()
        mock_empty_cache.assert_called_once()

    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clearing failed"))
    def test_empty_cache_failure(self, mock_empty_cache):
        with self.assertRaises(RuntimeError):
            self.platform.empty_cache()
        mock_empty_cache.assert_called_once()

    @patch("torch.npu.synchronize")
    def test_synchronize_normal(self, mock_synchronize):
        self.platform.synchronize()
        mock_synchronize.assert_called_once()

    @patch("torch.npu.synchronize",
           side_effect=RuntimeError("Synchronization failed"))
    def test_synchronize_failure(self, mock_synchronize):
        with self.assertRaises(RuntimeError):
            self.platform.synchronize()
        mock_synchronize.assert_called_once()

    @patch("torch.npu.mem_get_info")
    def test_mem_get_info_normal(self, mock_mem_get_info):
        free_memory_size = 1024
        total_memory_size = 2048
        memory_info = (free_memory_size, total_memory_size)
        mock_mem_get_info.return_value = memory_info
        result = self.platform.mem_get_info()
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        self.assertEqual(result, memory_info)
        mock_mem_get_info.assert_called_once()

    @patch("torch.npu.mem_get_info",
           side_effect=RuntimeError("NPU not available"))
    def test_mem_get_info_failure(self, mock_mem_get_info):
        with self.assertRaises(RuntimeError):
            self.platform.mem_get_info()
        mock_mem_get_info.assert_called_once()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache,
                                     mock_gc_collect):
        self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("gc.collect", side_effect=Exception("GC failed"))
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
                                                 mock_empty_cache,
                                                 mock_gc_collect):
        with self.assertRaises(Exception):
            self.platform.clear_npu_memory()

        # A failure in the first step must short-circuit the later ones.
        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_not_called()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clear failed"))
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Reset failed"))
    def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.update_aclgraph_sizes")
    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("os.environ", {})
    def test_check_and_update_config_basic_config_update(
            self, mock_is_310p, mock_update_acl, mock_init_ascend,
            mock_check_ascend):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.parallel_config.enable_expert_parallel = False

        # Reload the platform module so check_and_update_config uses the
        # mocked init_ascend_config; without the reload it would call the
        # original, unmocked helper and the test would fail.
        self._reload_platform()

        self.platform.check_and_update_config(self.mock_vllm_config)

        mock_init_ascend.assert_called_once_with(self.mock_vllm_config)
        mock_check_ascend.assert_called_once()

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_no_model_config_warning(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config = None

        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self._reload_platform()
            self.platform.check_and_update_config(self.mock_vllm_config)
        self.assertIn("Model config is missing", cm.output[0])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_enforce_eager_mode(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = True

        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self._reload_platform()
            self.platform.check_and_update_config(self.mock_vllm_config)
        self.assertIn("Compilation disabled, using eager mode by default",
                      cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_unsupported_compilation_level(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = False
        self.mock_vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE

        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self._reload_platform()
            self.platform.check_and_update_config(self.mock_vllm_config)
        self.assertIn("NPU does not support", cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_torchair_enabled_compilation(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        self.mock_ascend_config.torchair_graph_config.enabled = True
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = False
        self.mock_vllm_config.compilation_config.level = CompilationLevel.PIECEWISE

        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self._reload_platform()
            self.platform.check_and_update_config(self.mock_vllm_config)
        self.assertIn("Torchair compilation enabled", cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_cache_config_block_size(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.cache_config.block_size = None
        self.mock_vllm_config.cache_config.enable_prefix_caching = True

        self._reload_platform()

        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(self.mock_vllm_config.cache_config.block_size, 128)

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_v1_worker_class_selection(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.parallel_config.worker_cls = "auto"

        self._reload_platform()
        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(
            self.mock_vllm_config.parallel_config.worker_cls,
            "vllm_ascend.worker.worker_v1.NPUWorker",
        )

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.is_310p", return_value=True)
    def test_check_and_update_config_310p_no_custom_ops(
            self, mock_is_310p, mock_init_ascend, mock_check_ascend):
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.compilation_config.custom_ops = []

        self._reload_platform()

        self.platform.check_and_update_config(self.mock_vllm_config)
        self.assertEqual(self.mock_vllm_config.compilation_config.custom_ops,
                         [])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_ascend_scheduler_config(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        self.mock_ascend_config.ascend_scheduler_config.enabled = True
        mock_init_ascend.return_value = self.mock_ascend_config

        with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
                   ) as mock_scheduler:
            self._reload_platform()
            self.platform.check_and_update_config(self.mock_vllm_config)
            mock_scheduler.initialize_from_config.assert_called_once()

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False

        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=True,
        )
        self.assertEqual(result,
                         "vllm_ascend.attention.mla_v1.AscendMLABackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_torchair(self,
                                                      mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True

        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )
        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
        )

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False

        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )
        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1.AscendAttentionBackend")

    def test_get_punica_wrapper(self):
        result = self.platform.get_punica_wrapper()
        self.assertEqual(
            result,
            "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_specific_device(
            self, mock_max_memory, mock_reset_stats):
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result
        test_device = torch.device("npu:0")
        result = self.platform.get_current_memory_usage(device=test_device)

        mock_reset_stats.assert_called_once_with(test_device)
        mock_max_memory.assert_called_once_with(test_device)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_default_device(
            self, mock_max_memory, mock_reset_stats):
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result

        result = self.platform.get_current_memory_usage()

        mock_reset_stats.assert_called_once_with(None)
        mock_max_memory.assert_called_once_with(None)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Device error"))
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_when_reset_stats_fails(
            self, mock_max_memory, mock_reset_stats):
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()
        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_not_called()

    @patch("torch.npu.reset_peak_memory_stats")
    @patch(
        "torch.npu.max_memory_allocated",
        side_effect=RuntimeError("Memory query failed"),
    )
    def test_get_current_memory_usage_when_query_fails(self, mock_max_memory,
                                                       mock_reset_stats):
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()
        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_called_once()

    def test_get_device_communicator_cls_returns_correct_value(self):
        self.assertEqual(
            self.platform.get_device_communicator_cls(),
            "vllm_ascend.distributed.communicator.NPUCommunicator",
        )

    def test_is_pin_memory_available_returns_true(self):
        self.assertTrue(self.platform.is_pin_memory_available())

    def test_supports_v1(self):
        from vllm.config import ModelConfig

        mock_config = MagicMock(spec=ModelConfig)
        self.assertTrue(self.platform.supports_v1(mock_config))

    def test_get_piecewise_backend_cls_returns_correct_value(self):
        self.assertEqual(
            self.platform.get_piecewise_backend_cls(),
            "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend",
        )

    @patch("torch.distributed.is_hccl_available", return_value=True)
    @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL")
    @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options")
    @patch("torch.distributed.ProcessGroup")
    def test_successful_initialization(self, mock_pg, mock_options_cls,
                                       mock_pg_hccl, _):
        mock_prefix = MagicMock(spec=PrefixStore)
        mock_options = MagicMock(spec=ProcessGroup.Options)
        mock_options_cls.return_value = mock_options
        mock_backend = MagicMock()
        mock_pg_hccl.return_value = mock_backend
        group_rank = 0
        group_size = 4

        mock_pg_instance = MagicMock(spec=ProcessGroup)
        mock_pg.return_value = mock_pg_instance

        # Reload the platform module so stateless_init_device_torch_dist_pg
        # uses the mocked ProcessGroup; without the reload it would invoke the
        # original, unmocked implementation and the test would fail.
        self._reload_platform()

        result = self.platform.stateless_init_device_torch_dist_pg(
            backend="hccl",
            prefix_store=mock_prefix,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timedelta(seconds=30),
        )

        mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size,
                                        unittest.mock.ANY)
        mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank,
                                             group_size, unittest.mock.ANY)
        mock_backend._set_sequence_number_for_group.assert_called_once()
        mock_pg_instance._register_backend.assert_called_once_with(
            torch.device("npu"), unittest.mock.ANY, mock_backend)
        self.assertEqual(result, mock_pg_instance)

    @patch("torch.distributed.is_hccl_available", return_value=False)
    def test_hccl_unavailable(self, _):
        with self.assertRaises(AssertionError):
            self._reload_platform()
            self.platform.stateless_init_device_torch_dist_pg(
                backend="hccl",
                prefix_store=MagicMock(),
                group_rank=0,
                group_size=4,
                timeout=timedelta(seconds=30),
            )
|