import importlib
import unittest
from datetime import timedelta
from unittest.mock import MagicMock, patch

import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.config import CompilationLevel
from vllm.platforms import PlatformEnum

from tests.ut.base import TestBase
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD


class TestNPUPlatform(TestBase):
    """Unit tests for ``NPUPlatform`` with torch.npu and config mocked."""

    def setUp(self):
        """Create the platform under test and MagicMock-based configs."""
        self.platform = NPUPlatform()

        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.compilation_config = MagicMock()
        self.mock_vllm_config.model_config = MagicMock()
        self.mock_vllm_config.parallel_config = MagicMock()
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.scheduler_config = MagicMock()
        self.mock_vllm_config.speculative_config = None

        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.expert_tensor_parallel_size = 0
        self.mock_ascend_config.torchair_graph_config.enabled = False
        self.mock_ascend_config.ascend_scheduler_config.enabled = False

    def _reload_platform_module(self):
        """Reload ``vllm_ascend.platform`` so it picks up active patches.

        Without the reload, calling methods on ``self.platform`` would
        execute the original, unmocked objects (e.g. ``init_ascend_config``
        or ``ProcessGroup``) instead of the test mocks, and the test would
        fail.

        A second reload is registered as a cleanup. Cleanups run after the
        test method has returned, i.e. once its patches are no longer
        active, so the mocks do not stay bound in the module and leak into
        later tests.
        """
        from vllm_ascend import platform

        importlib.reload(platform)
        self.addCleanup(importlib.reload, platform)

    def test_class_variables(self):
        """Class-level constants have the expected NPU values."""
        self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
        self.assertEqual(NPUPlatform.device_name, "npu")
        self.assertEqual(NPUPlatform.device_type, "npu")
        self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
        self.assertEqual(NPUPlatform.ray_device_key, "NPU")
        self.assertEqual(NPUPlatform.device_control_env_var,
                         "ASCEND_RT_VISIBLE_DEVICES")
        self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
        self.assertEqual(NPUPlatform.supported_quantization,
                         [ASCEND_QUATIZATION_METHOD])

    def test_is_sleep_mode_available(self):
        """Sleep mode is reported as available."""
        self.assertTrue(self.platform.is_sleep_mode_available())

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser(self, mock_quant_config,
                                                 mock_adapt_patch):
        """The ascend method is appended to the parser's quant choices."""
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", "gptq"]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)
        self.assertIn(ASCEND_QUATIZATION_METHOD, mock_action.choices)
        self.assertEqual(len(mock_action.choices), 3)  # original 2 + ascend

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_without_parser(self, mock_quant_config,
                                                    mock_adapt_patch):
        """Passing ``None`` as parser still applies the global patch."""
        self.platform.pre_register_and_update(None)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser_no_quant_action(
            self, mock_quant_config, mock_adapt_patch):
        """A parser without a ``--quantization`` action is accepted."""
        mock_parser = MagicMock()
        mock_parser._option_string_actions = {}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_existing_ascend_quant(
            self, mock_quant_config, mock_adapt_patch):
        """The ascend method is not added twice when already present."""
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", ASCEND_QUATIZATION_METHOD]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)
        self.assertEqual(len(mock_action.choices), 2)

    def test_get_device_capability(self):
        """``get_device_capability`` returns ``None``."""
        self.assertIsNone(self.platform.get_device_capability(device_id=0))

    @patch("torch.npu.get_device_name")
    def test_get_device_name(self, mock_get_device_name):
        """The name comes from ``torch.npu.get_device_name``."""
        device_id = 0
        device_name = "Ascend910B2"
        mock_get_device_name.return_value = device_name

        self.assertEqual(self.platform.get_device_name(device_id),
                         device_name)
        mock_get_device_name.assert_called_once_with(0)

    def test_is_async_output_supported(self):
        """Async output is supported for every ``enforce_eager`` value."""
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=None))
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=True))
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=False))

    @patch("torch.inference_mode")
    def test_inference_mode(self, mock_inference_mode):
        """``inference_mode`` returns what ``torch.inference_mode`` does."""
        mock_inference_mode.return_value = None

        self.assertIsNone(self.platform.inference_mode())
        mock_inference_mode.assert_called_once()

    @patch("torch.npu.set_device")
    def test_set_device_normal(self, mock_set_device):
        """``set_device`` forwards the device to ``torch.npu.set_device``."""
        device = torch.device("npu:0")

        self.platform.set_device(device)

        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.set_device",
           side_effect=RuntimeError("Device not available"))
    def test_set_device_failure(self, mock_set_device):
        """A ``RuntimeError`` from ``torch.npu.set_device`` propagates."""
        device = torch.device("npu:0")

        with self.assertRaises(RuntimeError):
            self.platform.set_device(device)
        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.empty_cache")
    def test_empty_cache_normal(self, mock_empty_cache):
        """``empty_cache`` calls ``torch.npu.empty_cache`` once."""
        self.platform.empty_cache()

        mock_empty_cache.assert_called_once()

    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clearing failed"))
    def test_empty_cache_failure(self, mock_empty_cache):
        """A ``RuntimeError`` from ``torch.npu.empty_cache`` propagates."""
        with self.assertRaises(RuntimeError):
            self.platform.empty_cache()
        mock_empty_cache.assert_called_once()

    @patch("torch.npu.synchronize")
    def test_synchronize_normal(self, mock_synchronize):
        """``synchronize`` calls ``torch.npu.synchronize`` once."""
        self.platform.synchronize()

        mock_synchronize.assert_called_once()

    @patch("torch.npu.synchronize",
           side_effect=RuntimeError("Synchronization failed"))
    def test_synchronize_failure(self, mock_synchronize):
        """A ``RuntimeError`` from ``torch.npu.synchronize`` propagates."""
        with self.assertRaises(RuntimeError):
            self.platform.synchronize()
        mock_synchronize.assert_called_once()

    @patch("torch.npu.mem_get_info")
    def test_mem_get_info_normal(self, mock_mem_get_info):
        """``mem_get_info`` returns the ``(free, total)`` tuple unchanged."""
        free_memory_size = 1024
        total_memory_size = 2048
        memory_info = (free_memory_size, total_memory_size)
        mock_mem_get_info.return_value = memory_info

        result = self.platform.mem_get_info()

        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        self.assertEqual(result, memory_info)
        mock_mem_get_info.assert_called_once()

    @patch("torch.npu.mem_get_info",
           side_effect=RuntimeError("NPU not available"))
    def test_mem_get_info_failure(self, mock_mem_get_info):
        """A ``RuntimeError`` from ``torch.npu.mem_get_info`` propagates."""
        with self.assertRaises(RuntimeError):
            self.platform.mem_get_info()
        mock_mem_get_info.assert_called_once()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_normal(self, mock_reset_stats,
                                     mock_empty_cache, mock_gc_collect):
        """``clear_npu_memory`` runs gc, cache clear and stats reset once."""
        self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("gc.collect", side_effect=Exception("GC failed"))
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
                                                 mock_empty_cache,
                                                 mock_gc_collect):
        """If ``gc.collect`` fails, the later steps are not reached."""
        with self.assertRaises(Exception):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_not_called()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clear failed"))
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        """If ``empty_cache`` fails, the stats reset is not reached."""
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Reset failed"))
    def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        """A failing stats reset propagates after the earlier steps ran."""
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.update_aclgraph_sizes")
    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("os.environ", {})
    def test_check_and_update_config_basic_config_update(
            self, mock_is_310p, mock_update_acl, mock_init_ascend,
            mock_check_ascend):
        """The ascend config is initialised from the vLLM config and checked."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.parallel_config.enable_expert_parallel = False

        self._reload_platform_module()
        self.platform.check_and_update_config(self.mock_vllm_config)

        mock_init_ascend.assert_called_once_with(self.mock_vllm_config)
        mock_check_ascend.assert_called_once()

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_expert_parallel_enabled(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """Expert parallelism sets expert TP to 1 and EP to the world size."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.parallel_config.enable_expert_parallel = True
        self.mock_vllm_config.parallel_config.tensor_parallel_size = 2
        self.mock_vllm_config.parallel_config.world_size_across_dp = 4

        self._reload_platform_module()
        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(
            self.mock_vllm_config.parallel_config.expert_tensor_parallel_size,
            1)
        self.assertEqual(
            self.mock_vllm_config.parallel_config.expert_parallel_size,
            self.mock_vllm_config.parallel_config.world_size_across_dp,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_no_model_config_warning(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """A missing model config produces a warning on the vllm logger."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config = None

        # Reload before capturing logs so that only records emitted by the
        # call under test can end up in ``cm.output``.
        self._reload_platform_module()
        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertIn("Model config is missing", cm.output[0])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_enforce_eager_mode(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """``enforce_eager`` disables compilation and logs that it did."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = True

        self._reload_platform_module()
        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertIn("Compilation disabled, using eager mode by default",
                      cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_unsupported_compilation_level(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """``DYNAMO_ONCE`` is downgraded to no compilation with a warning."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = False
        self.mock_vllm_config.compilation_config.level = (
            CompilationLevel.DYNAMO_ONCE)

        self._reload_platform_module()
        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertIn("NPU does not support", cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_torchair_enabled_compilation(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """With torchair enabled, the level ends up as no compilation."""
        self.mock_ascend_config.torchair_graph_config.enabled = True
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.model_config.enforce_eager = False
        self.mock_vllm_config.compilation_config.level = (
            CompilationLevel.PIECEWISE)

        self._reload_platform_module()
        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertIn("Torchair compilation enabled", cm.output[0])
        self.assertEqual(
            self.mock_vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_cache_config_block_size(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """An unset block size becomes 128 when prefix caching is enabled."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.cache_config.block_size = None
        self.mock_vllm_config.cache_config.enable_prefix_caching = True

        self._reload_platform_module()
        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(self.mock_vllm_config.cache_config.block_size, 128)

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_v1_worker_class_selection(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """``worker_cls == "auto"`` resolves to the v1 ``NPUWorker`` path."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.parallel_config.worker_cls = "auto"

        self._reload_platform_module()
        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(
            self.mock_vllm_config.parallel_config.worker_cls,
            "vllm_ascend.worker.worker_v1.NPUWorker",
        )

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.is_310p", return_value=True)
    def test_check_and_update_config_310p_no_custom_ops(
            self, mock_is_310p, mock_init_ascend, mock_check_ascend):
        """With ``is_310p`` true, an empty ``custom_ops`` list stays empty."""
        mock_init_ascend.return_value = self.mock_ascend_config
        self.mock_vllm_config.compilation_config.custom_ops = []

        self._reload_platform_module()
        self.platform.check_and_update_config(self.mock_vllm_config)

        self.assertEqual(self.mock_vllm_config.compilation_config.custom_ops,
                         [])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_ascend_scheduler_config(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """An enabled ascend scheduler initialises ``AscendSchedulerConfig``."""
        self.mock_ascend_config.ascend_scheduler_config.enabled = True
        mock_init_ascend.return_value = self.mock_ascend_config

        with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
                   ) as mock_scheduler:
            self._reload_platform_module()
            self.platform.check_and_update_config(self.mock_vllm_config)
            mock_scheduler.initialize_from_config.assert_called_once()

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_mla(self,
                                                 mock_get_ascend_config):
        """v1 + MLA without torchair selects ``AscendMLABackend``."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=True,
        )

        self.assertEqual(result,
                         "vllm_ascend.attention.mla_v1.AscendMLABackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_torchair(
            self, mock_get_ascend_config):
        """v1 without MLA but with torchair selects the torchair backend."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True
        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )

        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
        )

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
        """v1 without MLA or torchair selects the v1 attention backend."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )

        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1.AscendAttentionBackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_mla_only(self, mock_get_ascend_config):
        """MLA without v1 selects ``AscendMLAAttentionBackend``."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=False,
            use_mla=True,
        )

        self.assertEqual(
            result,
            "vllm_ascend.attention.attention.AscendMLAAttentionBackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_default_case(self, mock_get_ascend_config):
        """Neither v1 nor MLA selects the default ``AscendAttentionBackend``."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config

        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=False,
            use_mla=False,
        )

        self.assertEqual(
            result, "vllm_ascend.attention.attention.AscendAttentionBackend")

    def test_get_punica_wrapper(self):
        """The punica wrapper path points at ``PunicaWrapperNPU``."""
        result = self.platform.get_punica_wrapper()

        self.assertEqual(
            result,
            "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_specific_device(
            self, mock_max_memory, mock_reset_stats):
        """An explicit device is forwarded to both torch.npu calls."""
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result
        test_device = torch.device("npu:0")

        result = self.platform.get_current_memory_usage(device=test_device)

        mock_reset_stats.assert_called_once_with(test_device)
        mock_max_memory.assert_called_once_with(test_device)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_default_device(
            self, mock_max_memory, mock_reset_stats):
        """Without a device argument, ``None`` is forwarded to torch.npu."""
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result

        result = self.platform.get_current_memory_usage()

        mock_reset_stats.assert_called_once_with(None)
        mock_max_memory.assert_called_once_with(None)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Device error"))
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_when_reset_stats_fails(
            self, mock_max_memory, mock_reset_stats):
        """If the stats reset fails, the memory query is not reached."""
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()

        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_not_called()

    @patch("torch.npu.reset_peak_memory_stats")
    @patch(
        "torch.npu.max_memory_allocated",
        side_effect=RuntimeError("Memory query failed"),
    )
    def test_get_current_memory_usage_when_query_fails(self, mock_max_memory,
                                                       mock_reset_stats):
        """A failing memory query propagates after the stats reset ran."""
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()

        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_called_once()

    def test_get_device_communicator_cls_returns_correct_value(self):
        """The communicator path points at ``NPUCommunicator``."""
        self.assertEqual(
            self.platform.get_device_communicator_cls(),
            "vllm_ascend.distributed.communicator.NPUCommunicator",
        )

    def test_is_pin_memory_available_returns_true(self):
        """Pinned memory is reported as available."""
        self.assertTrue(self.platform.is_pin_memory_available())

    def test_supports_v1(self):
        """``supports_v1`` returns true for a ``ModelConfig``-spec mock."""
        from vllm.config import ModelConfig

        mock_config = MagicMock(spec=ModelConfig)
        self.assertTrue(self.platform.supports_v1(mock_config))

    def test_get_piecewise_backend_cls_returns_correct_value(self):
        """The piecewise backend path points at ``NPUPiecewiseBackend``."""
        self.assertEqual(
            self.platform.get_piecewise_backend_cls(),
            "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend",
        )

    @patch("torch.distributed.is_hccl_available", return_value=True)
    @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL")
    @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options")
    @patch("torch.distributed.ProcessGroup")
    def test_successful_initialization(self, mock_pg, mock_options_cls,
                                       mock_pg_hccl, _):
        """With HCCL available, a process group with an HCCL backend is built."""
        mock_prefix = MagicMock(spec=PrefixStore)
        mock_options = MagicMock(spec=ProcessGroup.Options)
        mock_options_cls.return_value = mock_options
        mock_backend = MagicMock()
        mock_pg_hccl.return_value = mock_backend
        group_rank = 0
        group_size = 4

        mock_pg_instance = MagicMock(spec=ProcessGroup)
        mock_pg.return_value = mock_pg_instance

        # Force the platform module to use the mocked ProcessGroup rather
        # than the original implementation.
        self._reload_platform_module()

        result = self.platform.stateless_init_device_torch_dist_pg(
            backend="hccl",
            prefix_store=mock_prefix,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timedelta(seconds=30),
        )

        mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size,
                                        unittest.mock.ANY)
        mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank,
                                             group_size, unittest.mock.ANY)
        mock_backend._set_sequence_number_for_group.assert_called_once()
        mock_pg_instance._register_backend.assert_called_once_with(
            torch.device("npu"), unittest.mock.ANY, mock_backend)
        self.assertEqual(result, mock_pg_instance)

    @patch("torch.distributed.is_hccl_available", return_value=False)
    def test_hccl_unavailable(self, _):
        """With HCCL unavailable, initialisation raises ``AssertionError``."""
        # Reload outside ``assertRaises`` so that only an AssertionError
        # raised by the call under test can satisfy the expectation.
        self._reload_platform_module()

        with self.assertRaises(AssertionError):
            self.platform.stateless_init_device_torch_dist_pg(
                backend="hccl",
                prefix_store=MagicMock(),
                group_rank=0,
                group_size=4,
                timeout=timedelta(seconds=30),
            )