Files
xc-llm-ascend/tests/ut/test_platform.py
Mengqing Cao 8cfd257992 [Dist][EP] Remove ETP/EP maintained in vllm-ascend (#1681)
### What this PR does / why we need it?
Remove ETP/EP maintained in branch main. We drop this as there are no
relevant scenarios that use ETP now, and we may subsequently advocate
implementing expert tensor parallelism in vLLM to support scenarios
where the experts need to be sliced.

This is a part of #1422 backport.

Fixes https://github.com/vllm-project/vllm-ascend/issues/1396
https://github.com/vllm-project/vllm-ascend/issues/1154

### Does this PR introduce _any_ user-facing change?
We will no longer maintain ETP/EP in vllm-ascend; use the TP/EP support in
vLLM instead.

### How was this patch tested?
CI passed with new added and existing test.


- vLLM version: v0.9.2
- vLLM main:
fe8a2c544a

Signed-off-by: MengqingCao <cmq0113@163.com>
2025-07-21 09:08:04 +08:00

589 lines
25 KiB
Python

import importlib
import unittest
from datetime import timedelta
from unittest.mock import MagicMock, patch
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.config import CompilationLevel
from vllm.platforms import PlatformEnum
from tests.ut.base import TestBase
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD
class TestNPUPlatform(TestBase):
    """Unit tests for vllm_ascend.platform.NPUPlatform."""

    def setUp(self):
        """Create a real NPUPlatform plus fully mocked vllm/ascend configs.

        Each sub-config is an independent MagicMock so individual tests can
        set only the attributes they care about; graph/scheduler features
        default to disabled.
        """
        self.platform = NPUPlatform()
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.compilation_config = MagicMock()
        self.mock_vllm_config.model_config = MagicMock()
        self.mock_vllm_config.parallel_config = MagicMock()
        self.mock_vllm_config.cache_config = MagicMock()
        self.mock_vllm_config.scheduler_config = MagicMock()
        # No speculative decoding by default.
        self.mock_vllm_config.speculative_config = None
        self.mock_ascend_config = MagicMock()
        self.mock_ascend_config.torchair_graph_config.enabled = False
        self.mock_ascend_config.ascend_scheduler_config.enabled = False
def test_class_variables(self):
self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
self.assertEqual(NPUPlatform.device_name, "npu")
self.assertEqual(NPUPlatform.device_type, "npu")
self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
self.assertEqual(NPUPlatform.ray_device_key, "NPU")
self.assertEqual(NPUPlatform.device_control_env_var,
"ASCEND_RT_VISIBLE_DEVICES")
self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
self.assertEqual(NPUPlatform.supported_quantization,
[ASCEND_QUATIZATION_METHOD])
def test_is_sleep_mode_available(self):
self.assertTrue(self.platform.is_sleep_mode_available())
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser(self, mock_quant_config,
mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", "gptq"]
mock_parser._option_string_actions = {"--quantization": mock_action}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
self.assertTrue(ASCEND_QUATIZATION_METHOD in mock_action.choices)
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_without_parser(self, mock_quant_config,
mock_adapt_patch):
self.platform.pre_register_and_update(None)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser_no_quant_action(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_parser._option_string_actions = {}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_existing_ascend_quant(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", ASCEND_QUATIZATION_METHOD]
mock_parser._option_string_actions = {"--quantization": mock_action}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
self.assertEqual(len(mock_action.choices), 2)
def test_get_device_capability(self):
self.assertIsNone(self.platform.get_device_capability(device_id=0))
@patch("torch.npu.get_device_name")
def test_get_device_name(self, mock_get_device_name):
device_id = 0
device_name = "Ascend910B2"
mock_get_device_name.return_value = device_name
self.assertEqual(self.platform.get_device_name(device_id), device_name)
mock_get_device_name.assert_called_once_with(0)
def test_is_async_output_supported(self):
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=None))
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=True))
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=False))
@patch("torch.inference_mode")
def test_inference_mode(self, mock_inference_mode):
mock_inference_mode.return_value = None
self.assertIsNone(self.platform.inference_mode())
mock_inference_mode.assert_called_once()
@patch("torch.npu.set_device")
def test_set_device_normal(self, mock_set_device):
device = torch.device("npu:0")
self.platform.set_device(device)
mock_set_device.assert_called_once_with(device)
@patch("torch.npu.set_device",
side_effect=RuntimeError("Device not available"))
def test_set_device_failure(self, mock_set_device):
device = torch.device("npu:0")
with self.assertRaises(RuntimeError):
self.platform.set_device(device)
mock_set_device.assert_called_once_with(device)
@patch("torch.npu.empty_cache")
def test_empty_cache_normal(self, mock_empty_cache):
self.platform.empty_cache()
mock_empty_cache.assert_called_once()
@patch("torch.npu.empty_cache",
side_effect=RuntimeError("Cache clearing failed"))
def test_empty_cache_failure(self, mock_empty_cache):
with self.assertRaises(RuntimeError):
self.platform.empty_cache()
mock_empty_cache.assert_called_once()
@patch("torch.npu.synchronize")
def test_synchronize_normal(self, mock_synchronize):
self.platform.synchronize()
mock_synchronize.assert_called_once()
@patch("torch.npu.synchronize",
side_effect=RuntimeError("Synchronization failed"))
def test_synchronize_failure(self, mock_synchronize):
with self.assertRaises(RuntimeError):
self.platform.synchronize()
mock_synchronize.assert_called_once()
@patch("torch.npu.mem_get_info")
def test_mem_get_info_normal(self, mock_mem_get_info):
free_memory_size = 1024
total_memory_size = 2048
memory_info = (free_memory_size, total_memory_size)
mock_mem_get_info.return_value = memory_info
result = self.platform.mem_get_info()
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)
self.assertEqual(result, memory_info)
mock_mem_get_info.assert_called_once()
@patch("torch.npu.mem_get_info",
side_effect=RuntimeError("NPU not available"))
def test_mem_get_info_failure(self, mock_mem_get_info):
with self.assertRaises(RuntimeError):
self.platform.mem_get_info()
mock_mem_get_info.assert_called_once()
@patch("gc.collect")
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache,
mock_gc_collect):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_called_once()
@patch("gc.collect", side_effect=Exception("GC failed"))
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(Exception):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_not_called()
mock_reset_stats.assert_not_called()
@patch("gc.collect")
@patch("torch.npu.empty_cache",
side_effect=RuntimeError("Cache clear failed"))
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(RuntimeError):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_not_called()
@patch("gc.collect")
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats",
side_effect=RuntimeError("Reset failed"))
def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(RuntimeError):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_called_once()
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.update_aclgraph_sizes")
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("os.environ", {})
def test_check_and_update_config_basic_config_update(
self, mock_is_310p, mock_update_acl, mock_init_ascend,
mock_check_ascend):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.parallel_config.enable_expert_parallel = False
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
# Without this reload, when calling self.platform.check_and_update_config,
# it would execute the original unmocked init_ascend_config method, causing the unit test to fail.
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
mock_init_ascend.assert_called_once_with(self.mock_vllm_config)
mock_check_ascend.assert_called_once()
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_no_model_config_warning(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config = None
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue("Model config is missing" in cm.output[0])
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_enforce_eager_mode(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config.enforce_eager = True
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue("Compilation disabled, using eager mode by default" in
cm.output[0])
self.assertEqual(
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_unsupported_compilation_level(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config.enforce_eager = False
self.mock_vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue("NPU does not support" in cm.output[0])
self.assertEqual(
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
self.mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.model_config.enforce_eager = False
self.mock_vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertTrue("Torchair compilation enabled" in cm.output[0])
self.assertEqual(
self.mock_vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_cache_config_block_size(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.cache_config.block_size = None
self.mock_vllm_config.cache_config.enable_prefix_caching = True
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertEqual(self.mock_vllm_config.cache_config.block_size, 128)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_v1_worker_class_selection(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.parallel_config.worker_cls = "auto"
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertEqual(
self.mock_vllm_config.parallel_config.worker_cls,
"vllm_ascend.worker.worker_v1.NPUWorker",
)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.is_310p", return_value=True)
def test_check_and_update_config_310p_no_custom_ops(
self, mock_is_310p, mock_init_ascend, mock_check_ascend):
mock_init_ascend.return_value = self.mock_ascend_config
self.mock_vllm_config.compilation_config.custom_ops = []
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
self.assertEqual(self.mock_vllm_config.compilation_config.custom_ops,
[])
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_ascend_scheduler_config(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
self.mock_ascend_config.ascend_scheduler_config.enabled = True
mock_init_ascend.return_value = self.mock_ascend_config
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
) as mock_scheduler:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(self.mock_vllm_config)
mock_scheduler.initialize_from_config.assert_called_once()
    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
        """V1 engine + MLA selects the MLA attention backend."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config
        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=True,
        )
        self.assertEqual(result,
                         "vllm_ascend.attention.mla_v1.AscendMLABackend")
    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_torchair(self,
                                                      mock_get_ascend_config):
        """V1 engine + torchair graph selects the torchair attention backend."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True
        mock_get_ascend_config.return_value = mock_config
        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )
        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1_torchair.AscendAttentionTorchairBackend"
        )
    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
        """V1 engine without MLA/torchair selects the plain v1 backend."""
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False
        mock_get_ascend_config.return_value = mock_config
        result = self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=False,
        )
        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1.AscendAttentionBackend")
def test_get_punica_wrapper(self):
result = self.platform.get_punica_wrapper()
self.assertEqual(
result,
"vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")
@patch("torch.npu.reset_peak_memory_stats")
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_with_specific_device(
self, mock_max_memory, mock_reset_stats):
max_memory_allocated_result = 1024.0
mock_max_memory.return_value = max_memory_allocated_result
test_device = torch.device("npu:0")
result = self.platform.get_current_memory_usage(device=test_device)
mock_reset_stats.assert_called_once_with(test_device)
mock_max_memory.assert_called_once_with(test_device)
self.assertEqual(result, max_memory_allocated_result)
@patch("torch.npu.reset_peak_memory_stats")
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_with_default_device(
self, mock_max_memory, mock_reset_stats):
max_memory_allocated_result = 1024.0
mock_max_memory.return_value = max_memory_allocated_result
result = self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once_with(None)
mock_max_memory.assert_called_once_with(None)
self.assertEqual(result, max_memory_allocated_result)
@patch("torch.npu.reset_peak_memory_stats",
side_effect=RuntimeError("Device error"))
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_when_reset_stats_fails(
self, mock_max_memory, mock_reset_stats):
with self.assertRaises(RuntimeError):
self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once()
mock_max_memory.assert_not_called()
@patch("torch.npu.reset_peak_memory_stats")
@patch(
"torch.npu.max_memory_allocated",
side_effect=RuntimeError("Memory query failed"),
)
def test_get_current_memory_usage_when_query_fails(self, mock_max_memory,
mock_reset_stats):
with self.assertRaises(RuntimeError):
self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once()
mock_max_memory.assert_called_once()
def test_get_device_communicator_cls_returns_correct_value(self):
self.assertEqual(
self.platform.get_device_communicator_cls(),
"vllm_ascend.distributed.communicator.NPUCommunicator",
)
def test_is_pin_memory_available_returns_true(self):
self.assertTrue(self.platform.is_pin_memory_available())
def test_supports_v1(self):
from vllm.config import ModelConfig
mock_config = MagicMock(spec=ModelConfig)
self.assertTrue(self.platform.supports_v1(mock_config))
def test_get_piecewise_backend_cls_returns_correct_value(self):
self.assertEqual(
self.platform.get_piecewise_backend_cls(),
"vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend",
)
@patch("torch.distributed.is_hccl_available", return_value=True)
@patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL")
@patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options")
@patch("torch.distributed.ProcessGroup")
def test_successful_initialization(self, mock_pg, mock_options_cls,
mock_pg_hccl, _):
mock_prefix = MagicMock(spec=PrefixStore)
mock_options = MagicMock(spec=ProcessGroup.Options)
mock_options_cls.return_value = mock_options
mock_backend = MagicMock()
mock_pg_hccl.return_value = mock_backend
group_rank = 0
group_size = 4
mock_pg_instance = MagicMock(spec=ProcessGroup)
mock_pg.return_value = mock_pg_instance
# Use importlib.reload() to force-reload the platform module and ensure the mocked ProcessGroup is used.
# Without this reload, when executing self.platform.stateless_init_device_torch_dist_pg(),
# it would invoke the original unmocked ProcessGroup implementation instead of our test mock,
# which would cause the unit test to fail.
from vllm_ascend import platform
importlib.reload(platform)
result = self.platform.stateless_init_device_torch_dist_pg(
backend="hccl",
prefix_store=mock_prefix,
group_rank=group_rank,
group_size=group_size,
timeout=timedelta(seconds=30),
)
mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size,
unittest.mock.ANY)
mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank,
group_size, unittest.mock.ANY)
mock_backend._set_sequence_number_for_group.assert_called_once()
mock_pg_instance._register_backend.assert_called_once_with(
torch.device("npu"), unittest.mock.ANY, mock_backend)
self.assertEqual(result, mock_pg_instance)
@patch("torch.distributed.is_hccl_available", return_value=False)
def test_hccl_unavailable(self, _):
with self.assertRaises(AssertionError):
from vllm_ascend import platform
importlib.reload(platform)
self.platform.stateless_init_device_torch_dist_pg(
backend="hccl",
prefix_store=MagicMock(),
group_rank=0,
group_size=4,
timeout=timedelta(seconds=30),
)