diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py index bc112c89..a20fa893 100644 --- a/tests/ut/eplb/core/test_eplb_utils.py +++ b/tests/ut/eplb/core/test_eplb_utils.py @@ -5,9 +5,7 @@ from unittest.mock import patch # isort: off import torch from vllm.config import VllmConfig -from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, - FusedMoEParallelConfig - ) +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig, FusedMoEParallelConfig from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.eplb.core.eplb_utils import init_eplb_config @@ -15,30 +13,23 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config class TestAscendConfig(unittest.TestCase): - - def setUp(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def setUp(self, mock_fix_incompatible_config): vllm_config = VllmConfig() vllm_config.additional_config = { "refresh": True, - "eplb_config": { - "dynamic_eplb": True, - "num_redundant_experts": 2 - } + "eplb_config": {"dynamic_eplb": True, "num_redundant_experts": 2}, } - moe_parallel_config = FusedMoEParallelConfig(2, 0, 1, 2, 1, 1, 1, 1, - True, "hccl") - moe_config = FusedMoEConfig(8, 8, 8192, 5, moe_parallel_config, - torch.float16) + moe_parallel_config = FusedMoEParallelConfig(2, 0, 1, 2, 1, 1, 1, 1, True, "hccl") + moe_config = FusedMoEConfig(8, 8, 8192, 5, moe_parallel_config, torch.float16) moe_config.supports_eplb = True self.vllm_config = vllm_config self.moe_config = moe_config - self.mock_npu = patch("torch.Tensor.npu", - new=lambda self: self).start() + self.mock_npu = patch("torch.Tensor.npu", new=lambda self: self).start() def test_init_eplb_config_with_eplb(self): eplb_config = init_ascend_config(self.vllm_config).eplb_config - _, expert_map, log2phy, redundant_experts = init_eplb_config( - eplb_config, 0, self.moe_config) + _, expert_map, log2phy, redundant_experts = init_eplb_config(eplb_config, 0, self.moe_config) gt_expert_map = torch.tensor([4, -1, -1, -1, 0, 1, 2, 3]) gt_log2phy = torch.tensor([9, 1, 2, 3, 5, 6, 7, 8]) self.assertTrue(torch.equal(expert_map, gt_expert_map)) @@ -47,11 +38,9 @@ class TestAscendConfig(unittest.TestCase): def test_init_eplb_config_with_eplb_withmap(self): _TEST_DIR = os.path.dirname(__file__) - self.vllm_config.additional_config["eplb_config"][ - "expert_map_path"] = _TEST_DIR + "/expert_map.json" + self.vllm_config.additional_config["eplb_config"]["expert_map_path"] = _TEST_DIR + "/expert_map.json" eplb_config = init_ascend_config(self.vllm_config).eplb_config - _, expert_map, log2phy, redundant_experts = init_eplb_config( - eplb_config, 0, self.moe_config) + _, expert_map, log2phy, redundant_experts = init_eplb_config(eplb_config, 0, self.moe_config) gt_expert_map = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3]) gt_log2phy = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8]) self.assertTrue(torch.equal(expert_map, gt_expert_map)) @@ -61,8 +50,7 @@ class TestAscendConfig(unittest.TestCase): def test_init_eplb_config_without_eplb(self): self.vllm_config.additional_config = {"refresh": True} eplb_config = init_ascend_config(self.vllm_config).eplb_config - _, expert_map, log2phy, redundant_experts = init_eplb_config( - eplb_config, 0, self.moe_config) + _, expert_map, log2phy, redundant_experts = init_eplb_config(eplb_config, 0, self.moe_config) gt_expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3]) print(expert_map, log2phy, redundant_experts) self.assertTrue(torch.equal(expert_map, gt_expert_map)) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index c08e75e0..c7c9f52d 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -13,18 +13,17 @@ # This file is a part of the vllm-ascend project. # +from unittest.mock import patch + from vllm.config import VllmConfig from tests.ut.base import TestBase -from vllm_ascend.ascend_config import (clear_ascend_config, get_ascend_config, - init_ascend_config) +from vllm_ascend.ascend_config import clear_ascend_config, get_ascend_config, init_ascend_config class TestAscendConfig(TestBase): - @staticmethod def _clean_up_ascend_config(func): - def wrapper(*args, **kwargs): clear_ascend_config() func(*args, **kwargs) @@ -33,7 +32,8 @@ class TestAscendConfig(TestBase): return wrapper @_clean_up_ascend_config - def test_init_ascend_config_without_additional_config(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_init_ascend_config_without_additional_config(self, mock_fix_incompatible_config): test_vllm_config = VllmConfig() # No additional config given, check the default value here. ascend_config = init_ascend_config(test_vllm_config) @@ -47,7 +47,8 @@ class TestAscendConfig(TestBase): self.assertTrue(ascend_fusion_config.fusion_ops_gmmswigluquant) @_clean_up_ascend_config - def test_init_ascend_config_with_additional_config(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_init_ascend_config_with_additional_config(self, mock_fix_incompatible_config): test_vllm_config = VllmConfig() test_vllm_config.additional_config = { "ascend_compilation_config": { @@ -57,11 +58,9 @@ class TestAscendConfig(TestBase): "fusion_ops_gmmswigluquant": False, }, "multistream_overlap_shared_expert": True, - "eplb_config": { - "num_redundant_experts": 2 - }, + "eplb_config": {"num_redundant_experts": 2}, "refresh": True, - "enable_kv_nz": False + "enable_kv_nz": False, } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.eplb_config.num_redundant_experts, 2) @@ -76,7 +75,8 @@ class TestAscendConfig(TestBase): self.assertFalse(ascend_fusion_config.fusion_ops_gmmswigluquant) @_clean_up_ascend_config - def test_init_ascend_config_enable_npugraph_ex(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_init_ascend_config_enable_npugraph_ex(self, mock_fix_incompatible_config): test_vllm_config = VllmConfig() test_vllm_config.additional_config = { "enable_npugraph_ex": True, @@ -86,7 +86,8 @@ class TestAscendConfig(TestBase): self.assertTrue(ascend_config.enable_npugraph_ex) @_clean_up_ascend_config - def test_get_ascend_config(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_get_ascend_config(self, mock_fix_incompatible_config): test_vllm_config = VllmConfig() ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(get_ascend_config(), ascend_config) @@ -97,7 +98,8 @@ class TestAscendConfig(TestBase): get_ascend_config() @_clean_up_ascend_config - def test_clear_ascend_config(self): + @patch("vllm_ascend.platform.NPUPlatform._fix_incompatible_config") + def test_clear_ascend_config(self, mock_fix_incompatible_config): test_vllm_config = VllmConfig() ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(get_ascend_config(), ascend_config) diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 58919331..ba61744b 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -8,12 +8,10 @@ from vllm.platforms import PlatformEnum from tests.ut.base import TestBase from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, - COMPRESSED_TENSORS_METHOD, AscendDeviceType, - vllm_version_is) +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD, AscendDeviceType, vllm_version_is # isort: off -if vllm_version_is('0.13.0'): +if vllm_version_is("0.13.0"): from vllm.attention.selector import AttentionSelectorConfig # type: ignore else: from vllm.v1.attention.selector import AttentionSelectorConfig # type: ignore @@ -21,7 +19,6 @@ else: class TestNPUPlatform(TestBase): - @staticmethod def mock_vllm_config(): mock_vllm_config = MagicMock() @@ -44,9 +41,7 @@ class TestNPUPlatform(TestBase): def setUp(self): self.platform = NPUPlatform() - self.platform.supported_quantization[:] = [ - "ascend", "compressed-tensors" - ] + self.platform.supported_quantization[:] = ["ascend", "compressed-tensors"] def test_class_variables(self): self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT) @@ -54,20 +49,16 @@ class TestNPUPlatform(TestBase): self.assertEqual(NPUPlatform.device_type, "npu") self.assertEqual(NPUPlatform.simple_compile_backend, "eager") self.assertEqual(NPUPlatform.ray_device_key, "NPU") - self.assertEqual(NPUPlatform.device_control_env_var, - "ASCEND_RT_VISIBLE_DEVICES") + self.assertEqual(NPUPlatform.device_control_env_var, "ASCEND_RT_VISIBLE_DEVICES") self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1") - self.assertEqual( - NPUPlatform.supported_quantization, - [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]) + self.assertEqual(NPUPlatform.supported_quantization, [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD]) def test_is_sleep_mode_available(self): self.assertTrue(self.platform.is_sleep_mode_available()) @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") - def test_pre_register_and_update_with_parser(self, mock_quant_config, - mock_adapt_patch): + def test_pre_register_and_update_with_parser(self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() mock_action = MagicMock() mock_action.choices = ["awq", "gptq"] @@ -82,16 +73,14 @@ class TestNPUPlatform(TestBase): @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") - def test_pre_register_and_update_without_parser(self, mock_quant_config, - mock_adapt_patch): + def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch): self.platform.pre_register_and_update(None) mock_adapt_patch.assert_called_once_with(is_global_patch=True) @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") - def test_pre_register_and_update_with_parser_no_quant_action( - self, mock_quant_config, mock_adapt_patch): + def test_pre_register_and_update_with_parser_no_quant_action(self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() mock_parser._option_string_actions = {} @@ -101,8 +90,7 @@ class TestNPUPlatform(TestBase): @patch("vllm_ascend.utils.adapt_patch") @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") - def test_pre_register_and_update_with_existing_ascend_quant( - self, mock_quant_config, mock_adapt_patch): + def test_pre_register_and_update_with_existing_ascend_quant(self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() mock_action = MagicMock() mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD] @@ -132,17 +120,13 @@ class TestNPUPlatform(TestBase): @patch("vllm_ascend.ascend_config.init_ascend_config") @patch("vllm_ascend.utils.update_aclgraph_sizes") - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("os.environ", {}) - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") def test_check_and_update_config_basic_config_update( - self, mock_init_recompute, mock_soc_version, mock_update_acl, - mock_init_ascend): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + self, mock_init_recompute, mock_soc_version, mock_update_acl, mock_init_ascend + ): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.parallel_config.enable_expert_parallel = False vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -164,16 +148,13 @@ class TestNPUPlatform(TestBase): mock_init_ascend.assert_called_once_with(vllm_config) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") def test_check_and_update_config_no_model_config_warning( - self, mock_init_recompute, mock_init_ascend, mock_soc_version): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + self, mock_init_recompute, mock_init_ascend, mock_soc_version + ): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config = None vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -186,19 +167,18 @@ class TestNPUPlatform(TestBase): from vllm_ascend import platform importlib.reload(platform) - self.platform.check_and_update_config(vllm_config) + self.platform = platform.NPUPlatform() + + with patch.object(platform.NPUPlatform, "_fix_incompatible_config"): + self.platform.check_and_update_config(vllm_config) + self.assertTrue("Model config is missing" in cm.output[0]) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) - def test_check_and_update_config_enforce_eager_mode( - self, mock_init_recompute, mock_init_ascend, mock_soc_version): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_enforce_eager_mode(self, mock_init_recompute, mock_init_ascend, mock_soc_version): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = True vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -211,9 +191,12 @@ class TestNPUPlatform(TestBase): from vllm_ascend import platform importlib.reload(platform) - self.platform.check_and_update_config(vllm_config) - self.assertTrue("Compilation disabled, using eager mode by default" in - cm.output[0]) + self.platform = platform.NPUPlatform() + + with patch.object(platform.NPUPlatform, "_fix_incompatible_config"): + self.platform.check_and_update_config(vllm_config) + + self.assertTrue("Compilation disabled, using eager mode by default" in cm.output[0]) self.assertEqual( vllm_config.compilation_config.mode, @@ -225,19 +208,15 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.utils.update_default_aclgraph_sizes") @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") def test_check_and_update_config_unsupported_compilation_level( - self, mock_init_recompute, mock_init_ascend, mock_update_default, - mock_soc_version): + self, mock_init_recompute, mock_init_ascend, mock_update_default, mock_soc_version + ): mock_update_default.return_value = MagicMock() - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = False vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -252,7 +231,11 @@ class TestNPUPlatform(TestBase): from vllm_ascend import platform importlib.reload(platform) - self.platform.check_and_update_config(vllm_config) + self.platform = platform.NPUPlatform() + + with patch.object(platform.NPUPlatform, "_fix_incompatible_config"): + self.platform.check_and_update_config(vllm_config) + self.assertTrue("NPU does not support" in cm.output[0]) self.assertEqual( @@ -264,15 +247,11 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) - @pytest.mark.skip( - "Revert me when vllm support setting cudagraph_mode on oot platform") - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @pytest.mark.skip("Revert me when vllm support setting cudagraph_mode on oot platform") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") - def test_check_and_update_config_unsupported_cudagraph_mode( - self, mock_init_ascend, mock_soc_version): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + def test_check_and_update_config_unsupported_cudagraph_mode(self, mock_init_ascend, mock_soc_version): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.enforce_eager = False vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL @@ -282,9 +261,7 @@ class TestNPUPlatform(TestBase): importlib.reload(platform) self.platform.check_and_update_config(vllm_config) - self.assertTrue( - "cudagraph_mode is not support on NPU. falling back to NONE" in - cm.output[0]) + self.assertTrue("cudagraph_mode is not support on NPU. falling back to NONE" in cm.output[0]) self.assertEqual( vllm_config.compilation_config.mode, @@ -295,16 +272,13 @@ class TestNPUPlatform(TestBase): CUDAGraphMode.NONE, ) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") def test_check_and_update_config_cache_config_block_size( - self, mock_init_recompute, mock_init_ascend, mock_soc_version): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + self, mock_init_recompute, mock_init_ascend, mock_soc_version + ): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.cache_config.block_size = None vllm_config.cache_config.enable_prefix_caching = True @@ -322,16 +296,13 @@ class TestNPUPlatform(TestBase): self.assertEqual(vllm_config.cache_config.block_size, 128) - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType.A3) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") def test_check_and_update_config_v1_worker_class_selection( - self, mock_init_recompute, mock_init_ascend, mock_soc_version): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + self, mock_init_recompute, mock_init_ascend, mock_soc_version + ): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.parallel_config.worker_cls = "auto" vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -361,15 +332,10 @@ class TestNPUPlatform(TestBase): ) @patch("vllm_ascend.ascend_config.init_ascend_config") - @patch('vllm_ascend.utils.get_ascend_device_type', - return_value=AscendDeviceType._310P) - @patch( - "vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config" - ) - def test_check_and_update_config_310p_no_custom_ops( - self, mock_init_recompute, mock_soc_version, mock_init_ascend): - mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config( - ) + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType._310P) + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_310p_no_custom_ops(self, mock_init_recompute, mock_soc_version, mock_init_ascend): + mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config() vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.compilation_config.custom_ops = [] vllm_config.parallel_config.decode_context_parallel_size = 1 @@ -394,10 +360,8 @@ class TestNPUPlatform(TestBase): use_mla=True, use_sparse=False, ) - result = self.platform.get_attn_backend_cls("ascend", - attn_selector_config) - self.assertEqual(result, - "vllm_ascend.attention.mla_v1.AscendMLABackend") + result = self.platform.get_attn_backend_cls("ascend", attn_selector_config) + self.assertEqual(result, "vllm_ascend.attention.mla_v1.AscendMLABackend") def test_get_attn_backend_cls_use_v1_only(self): attn_selector_config = AttentionSelectorConfig( @@ -408,22 +372,17 @@ class TestNPUPlatform(TestBase): use_mla=False, use_sparse=False, ) - result = self.platform.get_attn_backend_cls("ascend", - attn_selector_config) - self.assertEqual( - result, - "vllm_ascend.attention.attention_v1.AscendAttentionBackend") + result = self.platform.get_attn_backend_cls("ascend", attn_selector_config) + self.assertEqual(result, "vllm_ascend.attention.attention_v1.AscendAttentionBackend") def test_get_punica_wrapper(self): result = self.platform.get_punica_wrapper() - self.assertEqual(result, - "vllm_ascend.lora.punica_npu.PunicaWrapperNPU") + self.assertEqual(result, "vllm_ascend.lora.punica_npu.PunicaWrapperNPU") @patch("torch.npu.reset_peak_memory_stats") @patch("torch.npu.max_memory_allocated") - def test_get_current_memory_usage_with_specific_device( - self, mock_max_memory, mock_reset_stats): + def test_get_current_memory_usage_with_specific_device(self, mock_max_memory, mock_reset_stats): max_memory_allocated_result = 1024.0 mock_max_memory.return_value = max_memory_allocated_result test_device = torch.device("npu:0") @@ -435,8 +394,7 @@ class TestNPUPlatform(TestBase): @patch("torch.npu.reset_peak_memory_stats") @patch("torch.npu.max_memory_allocated") - def test_get_current_memory_usage_with_default_device( - self, mock_max_memory, mock_reset_stats): + def test_get_current_memory_usage_with_default_device(self, mock_max_memory, mock_reset_stats): max_memory_allocated_result = 1024.0 mock_max_memory.return_value = max_memory_allocated_result @@ -446,11 +404,9 @@ class TestNPUPlatform(TestBase): mock_max_memory.assert_called_once_with(None) self.assertEqual(result, max_memory_allocated_result) - @patch("torch.npu.reset_peak_memory_stats", - side_effect=RuntimeError("Device error")) + @patch("torch.npu.reset_peak_memory_stats", side_effect=RuntimeError("Device error")) @patch("torch.npu.max_memory_allocated") - def test_get_current_memory_usage_when_reset_stats_fails( - self, mock_max_memory, mock_reset_stats): + def test_get_current_memory_usage_when_reset_stats_fails(self, mock_max_memory, mock_reset_stats): with self.assertRaises(RuntimeError): self.platform.get_current_memory_usage() mock_reset_stats.assert_called_once() @@ -461,8 +417,7 @@ class TestNPUPlatform(TestBase): "torch.npu.max_memory_allocated", side_effect=RuntimeError("Memory query failed"), ) - def test_get_current_memory_usage_when_query_fails(self, mock_max_memory, - mock_reset_stats): + def test_get_current_memory_usage_when_query_fails(self, mock_max_memory, mock_reset_stats): with self.assertRaises(RuntimeError): self.platform.get_current_memory_usage() mock_reset_stats.assert_called_once() diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b9ac4404..bf2c6c13 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -171,6 +171,7 @@ class NPUPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # initialize ascend config from vllm additional_config + cls._fix_incompatible_config(vllm_config) ascend_config = init_ascend_config(vllm_config) if vllm_config.kv_transfer_config is not None: @@ -620,3 +621,145 @@ class NPUPlatform(Platform): "max_tokens_across_dp": max_tokens_across_dp, "mc2_mask": mc2_mask, } + + @staticmethod + def _fix_incompatible_config(vllm_config: VllmConfig) -> None: + """ + Check and correct parameters in VllmConfig that are incompatible with Ascend NPU. + If GPU-specific or currently unsupported parameters are set by the user, + log a warning and reset them to safe values. + """ + # ==================== 1. Model Config ==================== + if vllm_config.model_config: + # Disable Cascade Attention (GPU feature) + if getattr(vllm_config.model_config, "disable_cascade_attn", False): + logger.warning( + "Parameter '--disable-cascade-attn' is a GPU-specific feature. Resetting to False for Ascend." + ) + vllm_config.model_config.disable_cascade_attn = False + + # ==================== 2. Parallel Config ==================== + if vllm_config.parallel_config: + # Only allow the default all2all backend; others like deepep are not supported + default_backend = "allgather_reducescatter" + current_backend = getattr(vllm_config.parallel_config, "all2all_backend", default_backend) + if current_backend != default_backend: + logger.warning( + "Parameter '--all2all-backend' is set to '%s', which may be " + "incompatible with Ascend. Using internal plugin mechanisms.", + current_backend, + ) + vllm_config.parallel_config.all2all_backend = default_backend + + # ==================== 3. Cache Config ==================== + # Check and reset cpu_kvcache_space_bytes + if getattr(vllm_config.cache_config, "cpu_kvcache_space_bytes", False): + logger.warning( + "Parameter 'cpu_kvcache_space_bytes' is tied to cpu backend. Resetting to None for Ascend." + ) + vllm_config.cache_config.cpu_kvcache_space_bytes = None + + # ==================== 4. MultiModal Config ==================== + if vllm_config.model_config.multimodal_config: + # Ascend uses a different mechanism for Multi-Modal attention + if getattr(vllm_config.model_config.multimodal_config, "mm_encoder_attn_backend", None) is not None: + logger.warning( + "Parameter '--mm-encoder-attn-backend' is set but Ascend uses " + "a plugin mechanism for multi-modal attention. Resetting to None." + ) + vllm_config.model_config.multimodal_config.mm_encoder_attn_backend = None + + # ==================== 5. Observability Config ==================== + if vllm_config.observability_config: + # NVTX tracing is NVIDIA specific + if getattr(vllm_config.observability_config, "enable_layerwise_nvtx_tracing", False): + logger.warning( + "Parameter '--enable-layerwise-nvtx-tracing' relies on NVTX " + "(NVIDIA Tools) and is not supported on Ascend. Resetting to False." + ) + vllm_config.observability_config.enable_layerwise_nvtx_tracing = False + + # ==================== 6. Scheduler Config ==================== + if vllm_config.scheduler_config: + # Partial prefills are specific to ROCm optimization + if getattr(vllm_config.scheduler_config, "max_num_partial_prefills", 1) != 1: + logger.warning( + "Parameter '--max-num-partial-prefills' is optimized for ROCm. Resetting to default (1) for Ascend." + ) + vllm_config.scheduler_config.max_num_partial_prefills = 1 + + # ==================== 7. Speculative Config ==================== + if vllm_config.speculative_config: + # Ascend automatically inherits main model quantization + if getattr(vllm_config.speculative_config, "quantization", None) is not None: + logger.warning( + "Speculative quantization is set but Ascend automatically uses " + "the main model's quantization method. Resetting to None." + ) + vllm_config.speculative_config.quantization = None + + # ==================== 8. KV Transfer Config ==================== + if vllm_config.kv_transfer_config: + # Buffer size is primarily tied to NCCL (GPU) backends + current_buffer_size = getattr(vllm_config.kv_transfer_config, "kv_buffer_size", 1e9) + if current_buffer_size != 1e9: + logger.warning( + "Parameter 'kv_buffer_size' is optimized for NCCL and may be " + "incompatible with current Ascend KV transfer status. Resetting to default (1e9)." + ) + # Use setattr to safely assign the value + vllm_config.kv_transfer_config.kv_buffer_size = 1e9 + + # Check and reset enable_permute_local_kv + if getattr(vllm_config.kv_transfer_config, "enable_permute_local_kv", False): + logger.warning( + "Parameter 'enable_permute_local_kv' is tied to NIXL backend. " + "Resetting to False for Ascend stability." + ) + vllm_config.kv_transfer_config.enable_permute_local_kv = False + + # ==================== 9. Attention Config ==================== + if vllm_config.attention_config: + att_config = vllm_config.attention_config + + # Boolean flags that must be False on Ascend (typically NVIDIA-specific) + force_false_flags = [ + "use_prefill_decode_attention", + "use_cudnn_prefill", + "use_trtllm_ragged_deepseek_prefill", + "use_trtllm_attention", + "disable_flashinfer_prefill", + "disable_flashinfer_q_quantization", + ] + for flag in force_false_flags: + if getattr(att_config, flag, False): + logger.warning( + "Ignored parameter '%s'. This is a GPU-specific feature " + "not supported on Ascend. Resetting to False.", + flag, + ) + setattr(att_config, flag, False) + + # Reset specific values to None as Ascend uses its own internal logic + if getattr(att_config, "flash_attn_version", None) is not None: + logger.warning( + "Ignored parameter 'flash_attn_version'. Ascend uses its own attention backend. Resetting to None." + ) + att_config.flash_attn_version = None + + # Notify user that the backend will be managed by Ascend plugins + if getattr(att_config, "backend", None) is not None: + logger.info( + "User specified attention backend '%s'. Note that Ascend NPU " + "will use its registered plugin backend instead. Resetting to None.", + att_config.backend, + ) + att_config.backend = None + + # CUDA Graph specific split points are not applicable + if getattr(att_config, "flash_attn_max_num_splits_for_cuda_graph", 32) != 32: + logger.warning( + "Parameter 'flash_attn_max_num_splits_for_cuda_graph' is " + "ignored on Ascend. Resetting to default (32)." + ) + att_config.flash_attn_max_num_splits_for_cuda_graph = 32