[Build] Add build info (#1386)

Add a static build_info .py file that exposes SOC and sleep-mode info. This
keeps the code clean and makes error messages friendlier for users.

This PR also adds unit tests for vllm_ascend/utils.py.

This PR also adds the base test class for all unit tests in tests/ut/base.py.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-06-27 09:14:43 +08:00
committed by GitHub
parent c563a08f0a
commit 5968dff4e0
11 changed files with 388 additions and 66 deletions

2
.gitignore vendored
View File

@@ -196,3 +196,5 @@ kernel_meta/
# version file generated by setuptools-scm
/vllm_ascend/_version.py
# build info file generated by setup.py
/vllm_ascend/_build_info.py

View File

@@ -27,6 +27,7 @@ from typing import Dict, List
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from setuptools.command.install import install
from setuptools_scm import get_version
@@ -78,6 +79,30 @@ class CMakeExtension(Extension):
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
class custom_build_info(build_py):
    """Generate vllm_ascend/_build_info.py before the regular build_py step.

    The generated module records the target SOC version and whether sleep
    mode (custom kernels) is enabled, so runtime code can raise friendly
    errors instead of failing obscurely.

    Raises:
        ValueError: if SOC_VERSION is unset, or a 310-series SOC is built
            without COMPILE_CUSTOM_KERNELS.
    """

    def run(self):
        soc_version = envs.SOC_VERSION
        if not soc_version:
            raise ValueError(
                "SOC version is not set. Please set SOC_VERSION environment variable."
            )
        # 310-series SOCs require the custom kernels to function at all.
        if "310" in soc_version and not envs.COMPILE_CUSTOM_KERNELS:
            raise ValueError(
                "SOC version 310 only supports custom kernels. Please set COMPILE_CUSTOM_KERNELS=1 to enable custom kernels."
            )
        package_dir = os.path.join(ROOT_DIR, "vllm_ascend", "_build_info.py")
        # "w" truncates and writes; the read capability of "w+" was unused.
        with open(package_dir, "w") as f:
            f.write('# Auto-generated file\n')
            f.write(f"__soc_version__ = '{soc_version}'\n")
            f.write(
                f"__sleep_mode_enabled__ = {envs.COMPILE_CUSTOM_KERNELS}\n")
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logging.info("Generated _build_info.py with SOC version: %s",
                     soc_version)
        super().run()
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config: Dict[str, bool] = {}
@@ -326,7 +351,11 @@ def get_requirements() -> List[str]:
return requirements
cmdclass = {"build_ext": cmake_build_ext, "install": custom_install}
cmdclass = {
"build_py": custom_build_info,
"build_ext": cmake_build_ext,
"install": custom_install
}
setup(
name="vllm_ascend",

12
tests/ut/base.py Normal file
View File

@@ -0,0 +1,12 @@
import unittest
from vllm_ascend.utils import adapt_patch
class TestBase(unittest.TestCase):
    """Common base class for all vllm-ascend unit tests.

    Ensures the vllm-ascend monkey patches are applied before every test.
    """

    def setUp(self):
        # Apply the platform-level (global) patches first, then the
        # worker-level ones, before each test case runs.
        adapt_patch(True)
        adapt_patch()
        super().setUp()

View File

@@ -0,0 +1,12 @@
from tests.ut.base import TestBase
class TestPatchDistributed(TestBase):
    """Checks that the distributed-state patch took effect."""

    def test_GroupCoordinator_patched(self):
        from vllm.distributed.parallel_state import GroupCoordinator

        from vllm_ascend.patch.worker.patch_common.patch_distributed import \
            GroupCoordinatorPatch

        # After TestBase.setUp applied the patches, the symbol exported by
        # vllm must be the Ascend replacement class itself.
        self.assertIs(GroupCoordinator, GroupCoordinatorPatch)

251
tests/ut/test_utils.py Normal file
View File

@@ -0,0 +1,251 @@
import math
import os
import unittest
from threading import Lock
from unittest import mock
import torch
from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig,
VllmConfig)
from vllm_ascend import utils
class TestUtils(unittest.TestCase):
    """Unit tests for the helper functions in vllm_ascend/utils.py."""

    def test_is_310p(self):
        """is_310p() reflects the SOC version baked into _build_info."""
        # Reset the module-level cache so the patched value is re-read.
        utils._IS_310P = None
        with mock.patch("vllm_ascend._build_info.__soc_version__",
                        "Ascend310P3"):
            self.assertTrue(utils.is_310p())
        utils._IS_310P = None
        with mock.patch("vllm_ascend._build_info.__soc_version__",
                        "Ascend910P1"):
            self.assertFalse(utils.is_310p())

    def test_sleep_mode_enabled(self):
        """sleep_mode_enabled() reflects the flag baked into _build_info."""
        # Reset the module-level cache so the patched value is re-read.
        utils._SLEEP_MODE_ENABLED = None
        with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
                        True):
            self.assertTrue(utils.sleep_mode_enabled())
        utils._SLEEP_MODE_ENABLED = None
        with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
                        False):
            self.assertFalse(utils.sleep_mode_enabled())

    def test_nd_to_nz_2d(self):
        """nd_to_nz_2d pads both dims to multiples of 16 and reshapes to
        (1, cols/16, rows, 16)."""
        # can be divided by 16
        input_tensor = torch.randn(32, 64)
        output = utils.nd_to_nz_2d(input_tensor)
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], 64 // 16)
        self.assertEqual(output.shape[2], 32)
        self.assertEqual(output.shape[3], 16)
        # cannot be divided by 16
        input_tensor = torch.randn(30, 62)
        output = utils.nd_to_nz_2d(input_tensor)
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], math.ceil(62 / 16))
        self.assertEqual(output.shape[2], 32)
        self.assertEqual(output.shape[3], 16)
        # pad to 16
        input_tensor = torch.randn(8, 12)
        output = utils.nd_to_nz_2d(input_tensor)
        self.assertEqual(output.shape[0], 1)
        self.assertEqual(output.shape[1], 1)  # 12->16, 16//16=1
        self.assertEqual(output.shape[2], 16)  # 8->16
        self.assertEqual(output.shape[3], 16)
        # check if the output is contiguous
        input_tensor = torch.randn(32, 64)
        output = utils.nd_to_nz_2d(input_tensor)
        self.assertTrue(output.is_contiguous())
        # check if the output values are preserved
        input_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
        output = utils.nd_to_nz_2d(input_tensor)
        # (2, 4) pads to (16, 16): original values in the top-left corner,
        # zeros everywhere else.
        expected = torch.tensor(
            [[[[1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]])
        self.assertTrue(torch.allclose(output, expected))

    def test_aligned_16(self):
        """aligned_16 pads dim 0 up to the next multiple of 16."""
        # align to 16
        input_tensor = torch.randn(15, 64)
        output_tensor = utils.aligned_16(input_tensor)
        self.assertEqual(output_tensor.shape[0], 16)
        # align to 16 — already aligned input must be returned unchanged
        input_tensor = torch.randn(16, 64)
        output_tensor = utils.aligned_16(input_tensor)
        self.assertEqual(output_tensor.shape[0], 16)
        self.assertTrue(torch.equal(input_tensor, output_tensor))
        # align to 32
        input_tensor = torch.randn(17, 64)
        output_tensor = utils.aligned_16(input_tensor)
        self.assertEqual(output_tensor.shape[0], 32)

    @mock.patch('importlib.util.find_spec')
    @mock.patch('importlib.import_module')
    def test_try_register_lib(self, mock_import_module, mock_find_spec):
        """try_register_lib imports when found and swallows failures."""
        # import OK
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.return_value = mock.MagicMock()
        lib_name = "existing_lib"
        lib_info = "Library found and imported successfully"
        utils.try_register_lib(lib_name, lib_info)
        mock_find_spec.assert_called_once_with(lib_name)
        mock_import_module.assert_called_once_with(lib_name)
        # Can't find lib — find_spec is consulted but no import happens.
        mock_find_spec.return_value = None
        lib_name = "non_existing_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(2, mock_find_spec.call_count)
        self.assertEqual(1, mock_import_module.call_count)
        # import error — the ImportError must be swallowed, not raised.
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.side_effect = ImportError("import error")
        lib_name = "error_lib"
        utils.try_register_lib(lib_name)
        self.assertEqual(3, mock_find_spec.call_count)
        self.assertEqual(2, mock_import_module.call_count)

    def test_enable_custom_op(self):
        """enable_custom_op caches success and turns ImportError into False."""
        result = utils.enable_custom_op()
        self.assertTrue(result)
        # Reset the cache, then force the import of vllm_ascend_C to fail.
        utils._CUSTOM_OP_ENABLED = None
        with mock.patch('builtins.__import__') as mock_import_module:
            mock_import_module.side_effect = ImportError("import error")
            self.assertFalse(utils.enable_custom_op())

    def test_find_hccl_library(self):
        """find_hccl_library honors HCCL_SO_PATH and requires CANN torch."""
        with mock.patch.dict(os.environ,
                             {"HCCL_SO_PATH": "/path/to/hccl/libhccl.so"}):
            self.assertEqual(utils.find_hccl_library(),
                             "/path/to/hccl/libhccl.so")
        # Without a CANN build of torch, the lookup must fail loudly.
        with mock.patch("torch.version.cann", None):
            self.assertRaises(ValueError, utils.find_hccl_library)
        with mock.patch("torch.version.cann", "Ascend910"):
            self.assertEqual(utils.find_hccl_library(), "libhccl.so")

    def test_current_stream(self):
        """current_stream falls back to torch.npu.current_stream()."""
        with mock.patch("torch.npu.current_stream") as mock_current_stream:
            self.assertEqual(utils.current_stream(), mock_current_stream())

    def test_vllm_version_is(self):
        """vllm_version_is: VLLM_VERSION env var overrides vllm.__version__."""
        with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
            with mock.patch("vllm.__version__", "1.0.0"):
                self.assertTrue(utils.vllm_version_is("1.0.0"))
                self.assertFalse(utils.vllm_version_is("2.0.0"))
            # env var wins even when __version__ disagrees
            with mock.patch("vllm.__version__", "2.0.0"):
                self.assertTrue(utils.vllm_version_is("1.0.0"))
                self.assertFalse(utils.vllm_version_is("2.0.0"))
        # without the env var, __version__ is authoritative
        with mock.patch("vllm.__version__", "1.0.0"):
            self.assertTrue(utils.vllm_version_is("1.0.0"))
            self.assertFalse(utils.vllm_version_is("2.0.0"))
        with mock.patch("vllm.__version__", "2.0.0"):
            self.assertTrue(utils.vllm_version_is("2.0.0"))
            self.assertFalse(utils.vllm_version_is("1.0.0"))

    def test_update_aclgraph_sizes(self):
        """update_aclgraph_sizes trims oversized capture-size lists."""
        # max_num_batch_sizes < len(original_sizes)
        test_compilation_config = CompilationConfig(
            cudagraph_capture_sizes=[i for i in range(150)])
        model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
        test_model_config = ModelConfig(model=model_path, enforce_eager=True)
        test_parallel_config = ParallelConfig()
        test_vllm_config = VllmConfig(
            model_config=test_model_config,
            compilation_config=test_compilation_config,
            parallel_config=test_parallel_config,
        )
        utils.update_aclgraph_sizes(test_vllm_config)
        # NOTE(review): 147 appears to be the computed max for this config —
        # confirm against update_aclgraph_sizes' sizing formula.
        self.assertEqual(
            147,
            len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
        # max_num_batch_sizes >= len(original_sizes): list kept as-is
        test_compilation_config = CompilationConfig(
            cudagraph_capture_sizes=[1, 2, 3])
        test_vllm_config = VllmConfig(
            model_config=test_model_config,
            compilation_config=test_compilation_config,
            parallel_config=test_parallel_config,
        )
        utils.update_aclgraph_sizes(test_vllm_config)
        self.assertEqual(
            3,
            len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
class TestProfileExecuteDuration(unittest.TestCase):
    """Tests for the ProfileExecuteDuration thread-safe singleton."""

    def setUp(self):
        # Reset all class-level singleton state so every test starts fresh.
        utils.ProfileExecuteDuration._instance = None
        utils.ProfileExecuteDuration._observations = []
        utils.ProfileExecuteDuration._lock = Lock()

    def test_singleton_creation(self):
        """Repeated construction must return the same instance."""
        instance1 = utils.ProfileExecuteDuration()
        self.assertIsNotNone(instance1)
        self.assertIs(instance1, utils.ProfileExecuteDuration._instance)
        instance2 = utils.ProfileExecuteDuration()
        self.assertIs(instance1, instance2)

    def test_thread_safety(self):
        """Concurrent construction from many threads yields one instance."""
        from threading import Thread
        instances = []

        def create_instance():
            instances.append(utils.ProfileExecuteDuration())

        threads = [Thread(target=create_instance) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        first_instance = instances[0]
        for instance in instances[1:]:
            self.assertIs(first_instance, instance)

    def test_atexit_registration(self):
        """Constructing the singleton registers destroy() with atexit."""
        with mock.patch('atexit.register') as mock_register:
            instance = utils.ProfileExecuteDuration()
            mock_register.assert_called_once_with(instance.destroy)

    def test_lock_usage(self):
        """Construction must acquire and release the class lock."""
        original_lock = utils.ProfileExecuteDuration._lock
        with mock.patch.object(utils.ProfileExecuteDuration,
                               '_lock',
                               wraps=original_lock) as mock_lock:
            utils.ProfileExecuteDuration()
            mock_lock.__enter__.assert_called()
            mock_lock.__exit__.assert_called()

    def test_observations_initialization(self):
        """A fresh singleton starts with an empty observations list."""
        instance = utils.ProfileExecuteDuration()
        self.assertEqual(instance._observations, [])

View File

@@ -138,7 +138,6 @@ class CaMemAllocator:
We cannot call the constructor directly.
Call this method to get the instance.
"""
assert camem_available, "camem allocator is not available"
if CaMemAllocator.instance is None:
CaMemAllocator.instance = CaMemAllocator()
return CaMemAllocator.instance

View File

@@ -155,7 +155,6 @@ def fused_experts_with_mc2(
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]

View File

@@ -23,7 +23,7 @@ import vllm.distributed
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm_ascend.utils import NullHandle, is_310p
from vllm_ascend.utils import is_310p
def ascend_destroy_model_parallel():
@@ -66,6 +66,15 @@ vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_pa
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
class NullHandle:
    """No-op stand-in for a distributed work handle.

    Returned by collective wrappers when there is no asynchronous work
    for the caller to wait on.
    """

    def __init__(self):
        pass

    def wait(self):
        # Nothing was launched asynchronously, so there is nothing to await.
        pass
def communication_adaptation_310p():
def broadcast310p(tensor, src, group=None, async_op=False):

View File

@@ -21,13 +21,11 @@ import atexit
import math
from contextlib import contextmanager, nullcontext
from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple
import torch
import torch_npu # noqa: F401 # noqa: F401
import torchair # type: ignore[import] # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger
@@ -55,75 +53,83 @@ else:
MAX_CAPTURE_SIZE = 1920
ASCEND_QUATIZATION_METHOD = "ascend"
CUSTOM_OP_ENABLED = None
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
ACL_FORMAT_FRACTAL_ND = 2
ACL_FORMAT_FRACTAL_NZ = 29
@lru_cache(maxsize=None)
def _get_soc_version():
    """Return the SOC name of NPU device 0, or "" when unavailable."""
    # No runtime or no devices: report an empty version string.
    if not torch.npu.is_available():
        return ""
    device_count = torch.npu.device_count()
    if device_count <= 0:
        return ""
    try:
        name = torch.npu.get_device_name(0)
    except Exception:
        # Device query can fail on partially initialized runtimes.
        name = ""
    return name


# Resolved once at import time; lru_cache keeps later calls cheap anyway.
_SOC_VERSION = _get_soc_version()
_CUSTOM_OP_ENABLED = None
_IS_310P = None
_SLEEP_MODE_ENABLED = None
_CURRENT_STREAM = None
def is_310p():
    """Return True when vllm-ascend was built for an Ascend 310P SOC.

    The SOC version is baked into vllm_ascend/_build_info.py at build time;
    the answer is cached in the module-level _IS_310P after the first call.
    (Fix: the merged-diff residue left an unconditional return of the old
    `_SOC_VERSION` check ahead of this logic, making it dead code.)
    """
    global _IS_310P
    if _IS_310P is None:
        # Imported lazily so importing utils does not require a completed
        # build to have generated _build_info.
        from vllm_ascend import _build_info  # type: ignore
        _IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
    return _IS_310P
class NullHandle:
    """Placeholder handle whose wait() does nothing.

    Used where an API expects a waitable handle but no real work exists.
    """

    def __init__(self):
        pass

    def wait(self):
        # No pending operation to block on.
        pass
def sleep_mode_enabled():
    """Return whether this build of vllm-ascend supports sleep mode.

    The flag is baked into vllm_ascend/_build_info.py by setup.py and is
    cached in the module-level _SLEEP_MODE_ENABLED after the first read.
    """
    global _SLEEP_MODE_ENABLED
    if _SLEEP_MODE_ENABLED is None:
        # Lazy import: _build_info only exists after a completed build.
        from vllm_ascend import _build_info  # type: ignore
        _SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
    return _SLEEP_MODE_ENABLED
def _round_up(x: int, align: int):
if align == 0:
return -1
# round up x to align, for example, if align is 16, x will be rounded up to 16, 32, 48, etc.
# input: 15, 16 -> output: 16
# input: 17, 16 -> output: 32
# input: 30, 16 -> output: 32
# input: 33, 16 -> output: 48
# ...
return (x + align - 1) // align * align
def _custom_pad(x, pad_dims):
# pad the input tensor to the shape of pad_dims
# input: (13, 30), pad_dims: [0, 2, 0, 3]
# output: (16, 32)
return torch.nn.functional.pad(x, pad_dims)
def _custom_reshape(x, target_shape):
# reshape the input tensor to the shape of target_shape
# input: (16, 32), target_shape: [1, 16, 2, 16]
# output: (1, 16, 2, 16)
return x.reshape(target_shape)
def _custom_transpose(x, dim1, dim2):
# transpose the input tensor
# input: (1, 16, 2, 16), dim1: 1, dim2: 2
# output: (1, 2, 16, 16)
return x.transpose(dim1, dim2)
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
    """Convert a 2-D ND tensor to NZ layout.

    Both dimensions are zero-padded up to multiples of 16 and the result is
    reshaped/transposed to shape (1, cols_aligned // 16, rows_aligned, 16).
    Example: a (13, 30) input produces a contiguous (1, 2, 16, 16) output.

    (Fix: the merged-diff residue assigned aux_dims[2], aux_dims[3] and
    pad_dims[1] twice; the duplicates are removed, behavior unchanged.)
    """
    rows_aligned = _round_up(in_tensor.size(0), 16)
    cols_aligned = _round_up(in_tensor.size(1), 16)
    # Shape before the final transpose: (1, rows', cols'/16, 16).
    aux_dims = [1, rows_aligned, cols_aligned // 16, 16]
    # pad order is (last dim first): pad cols to cols', then rows to rows'.
    pad_dims = [
        0, cols_aligned - in_tensor.size(1),
        0, rows_aligned - in_tensor.size(0)
    ]
    return _custom_transpose(
        _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
        2).contiguous()
@@ -187,24 +193,19 @@ def enable_custom_op():
Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
"""
global CUSTOM_OP_ENABLED
if CUSTOM_OP_ENABLED is not None:
return CUSTOM_OP_ENABLED
else:
try:
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
CUSTOM_OP_ENABLED = True
except ImportError:
CUSTOM_OP_ENABLED = False
logger.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
return CUSTOM_OP_ENABLED
global _CUSTOM_OP_ENABLED
if _CUSTOM_OP_ENABLED is not None:
return _CUSTOM_OP_ENABLED
try:
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
_CUSTOM_OP_ENABLED = True
except ImportError:
_CUSTOM_OP_ENABLED = False
logger.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
return _CUSTOM_OP_ENABLED
def find_hccl_library() -> str:
@@ -229,9 +230,6 @@ def find_hccl_library() -> str:
return so_file
_current_stream = None
def current_stream() -> torch.npu.Stream:
    """Replacement for `torch.npu.current_stream()`.

    Caches the stream in the module-level _CURRENT_STREAM so repeated calls
    avoid the overhead of querying torch.npu directly.

    (Fix: the merged-diff residue kept both the old `_current_stream` and
    the new `_CURRENT_STREAM` globals/branches; only the new one remains.)
    """
    global _CURRENT_STREAM
    if _CURRENT_STREAM is None:
        # when this function is called before any stream is set,
        # we return the default stream.
        _CURRENT_STREAM = torch.npu.current_stream()
    return _CURRENT_STREAM
def adapt_patch(is_global_patch: bool = False):
@@ -326,6 +324,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
len(original_sizes))
# TODO(wxy): Move to ops module
def dispose_tensor(x: torch.Tensor):
    """Release ``x``'s storage in place by repointing it at an empty tensor.

    Keeps device and dtype, but frees the underlying memory immediately
    instead of waiting for garbage collection.
    """
    empty = torch.empty((0, ), device=x.device, dtype=x.dtype)
    x.set_(empty)
@@ -378,10 +377,12 @@ class ProfileExecuteDuration:
return durations
# TODO(wxy): Move to ops module
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    """Context manager that switches to the tagged NPU stream.

    When ``enabled`` is False, returns a nullcontext so callers can use the
    same ``with`` statement unconditionally.
    """
    if not enabled:
        return nullcontext()
    return _npu_stream_switch(tag, priority)
# TODO(wxy): Move to ops module
def npu_wait_tensor(self: torch.Tensor,
dependency: torch.Tensor,
*,

View File

@@ -40,7 +40,7 @@ from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import try_register_lib
from vllm_ascend.utils import sleep_mode_enabled, try_register_lib
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -91,6 +91,10 @@ class NPUWorker(WorkerBase):
self.profiler = self._init_profiler()
def sleep(self, level: int = 1) -> None:
if not sleep_mode_enabled():
raise ValueError(
"Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
)
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
allocator = CaMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
@@ -104,6 +108,10 @@ class NPUWorker(WorkerBase):
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
    """Restore memory released by a prior sleep() via the CaMem allocator.

    Args:
        tags: optional allocator tags to wake; None wakes everything.

    Raises:
        ValueError: when this build was compiled without sleep-mode support.
    """
    # Guard first: without custom kernels the allocator cannot wake anything.
    if not sleep_mode_enabled():
        raise ValueError(
            "Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
        )
    allocator = CaMemAllocator.get_instance()
    allocator.wake_up(tags=tags)