[Build] Add build info (#1386)
Add a static _build_info py file that records the SoC version and whether sleep mode is enabled. This keeps the code clean, and the resulting error messages are friendlier for users.

This PR also adds unit tests for vllm_ascend/utils.py, plus a base test class for all unit tests in tests/ut/base.py.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
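For context, the generated file could look roughly like the sketch below. Only the __soc_version__ and __sleep_mode_enabled__ fields are confirmed by the diff; the concrete values, and the idea that the build script writes this file at wheel-build time, are assumptions.

# vllm_ascend/_build_info.py -- illustrative sketch of the generated file.
# NOTE: the values below are hypothetical examples, not real build output.
__soc_version__ = "Ascend910B1"   # SoC name detected when the wheel was built
__sleep_mode_enabled__ = True     # whether sleep-mode support was compiled in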
@@ -21,13 +21,11 @@ import atexit
 import math
 from contextlib import contextmanager, nullcontext
 from enum import Enum
-from functools import lru_cache
 from threading import Lock
 from typing import TYPE_CHECKING, List, Tuple

 import torch
 import torch_npu  # noqa: F401
 import torchair  # type: ignore[import]  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
 from vllm.logger import logger
@@ -55,75 +53,83 @@ else:
 MAX_CAPTURE_SIZE = 1920

 ASCEND_QUATIZATION_METHOD = "ascend"

-CUSTOM_OP_ENABLED = None
-
-SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
-
 ACL_FORMAT_FRACTAL_ND = 2
 ACL_FORMAT_FRACTAL_NZ = 29


-@lru_cache(maxsize=None)
-def _get_soc_version():
-    """Gets the SOC version and caches it."""
-    if not torch.npu.is_available():
-        return ""
-    device_count = torch.npu.device_count()
-    if device_count <= 0:
-        return ""
-    try:
-        return torch.npu.get_device_name(0)
-    except Exception:
-        return ""
-
-
-_SOC_VERSION = _get_soc_version()
+_CUSTOM_OP_ENABLED = None
+_IS_310P = None
+_SLEEP_MODE_ENABLED = None
+_CURRENT_STREAM = None


 def is_310p():
-    return _SOC_VERSION in SOC_VERSION_INFERENCE_SERIES
+    global _IS_310P
+    if _IS_310P is None:
+        from vllm_ascend import _build_info  # type: ignore
+        _IS_310P = _build_info.__soc_version__.lower().startswith(
+            "ascend310p")
+    return _IS_310P


 class NullHandle:

     def __init__(self):
         pass

     def wait(self):
         pass


+def sleep_mode_enabled():
+    global _SLEEP_MODE_ENABLED
+    if _SLEEP_MODE_ENABLED is None:
+        from vllm_ascend import _build_info  # type: ignore
+        _SLEEP_MODE_ENABLED = _build_info.__sleep_mode_enabled__
+    return _SLEEP_MODE_ENABLED


 def _round_up(x: int, align: int):
     # align == 0 is treated as invalid
     if align == 0:
         return -1
     # round x up to the nearest multiple of align; e.g. with align 16,
     # x is rounded up to 16, 32, 48, ...
     # input: 15, 16 -> output: 16
     # input: 17, 16 -> output: 32
     # input: 30, 16 -> output: 32
     # input: 33, 16 -> output: 48
     return (x + align - 1) // align * align


 def _custom_pad(x, pad_dims):
     # pad the input tensor by the amounts given in pad_dims
     # input: (13, 30), pad_dims: [0, 2, 0, 3]
     # output: (16, 32)
     return torch.nn.functional.pad(x, pad_dims)


 def _custom_reshape(x, target_shape):
     # reshape the input tensor to target_shape
     # input: (16, 32), target_shape: [1, 16, 2, 16]
     # output: (1, 16, 2, 16)
     return x.reshape(target_shape)


 def _custom_transpose(x, dim1, dim2):
     # swap dim1 and dim2 of the input tensor
     # input: (1, 16, 2, 16), dim1: 1, dim2: 2
     # output: (1, 2, 16, 16)
     return x.transpose(dim1, dim2)


 def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
-    aux_dims = [0, 0, 0, 0]
-    aux_dims[0] = 1
+    # in_tensor: (13, 30)
+    aux_dims = [1, 0, 0, 16]
+    # aux_dims[1]: 16
     aux_dims[1] = _round_up(in_tensor.size(0), 16)
+    # aux_dims[2]: 2
     aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
-    aux_dims[3] = 16
+    # after: aux_dims: [1, 16, 2, 16]

     pad_dims = [0, 0, 0, 0]
+    # pad_dims[1]: 2
     pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
+    # pad_dims[3]: 3
     pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
+    # after: pad_dims: [0, 2, 0, 3]

+    # return: (1, 2, 16, 16)
     return _custom_transpose(
         _custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
         2).contiguous()
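The ND-to-NZ helpers above document a worked example: a (13, 30) tensor is padded to (16, 32), reshaped to (1, 16, 2, 16), and transposed to (1, 2, 16, 16). A quick standalone check of that shape math (assuming a working vllm-ascend install; the helpers are plain torch ops and also run on CPU):

import torch

from vllm_ascend.utils import nd_to_nz_2d

x = torch.randn(13, 30)
y = nd_to_nz_2d(x)
# (13, 30) -> pad -> (16, 32) -> reshape -> (1, 16, 2, 16) -> transpose -> (1, 2, 16, 16)
print(tuple(y.shape))  # (1, 2, 16, 16)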
@@ -187,24 +193,19 @@ def enable_custom_op():
     Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
     Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
     """
-    global CUSTOM_OP_ENABLED
-
-    if CUSTOM_OP_ENABLED is not None:
-        return CUSTOM_OP_ENABLED
-
-    else:
-        try:
-            # register custom ops into torch_library here
-            import vllm_ascend.vllm_ascend_C  # type: ignore  # noqa: F401
-            CUSTOM_OP_ENABLED = True
-
-        except ImportError:
-            CUSTOM_OP_ENABLED = False
-            logger.warning(
-                "Warning: Failed to register custom ops, all custom ops will be disabled"
-            )
-
-    return CUSTOM_OP_ENABLED
+    global _CUSTOM_OP_ENABLED
+    if _CUSTOM_OP_ENABLED is not None:
+        return _CUSTOM_OP_ENABLED
+    try:
+        # register custom ops into torch_library here
+        import vllm_ascend.vllm_ascend_C  # type: ignore  # noqa: F401
+        _CUSTOM_OP_ENABLED = True
+    except ImportError:
+        _CUSTOM_OP_ENABLED = False
+        logger.warning(
+            "Warning: Failed to register custom ops, all custom ops will be disabled"
+        )
+    return _CUSTOM_OP_ENABLED


 def find_hccl_library() -> str:
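The docstring explains why registration is lazy: importing vllm_ascend.vllm_ascend_C initializes CANN's RTS component, so it must not happen at module import time. A minimal sketch of the resulting call pattern (the fallback branch is illustrative, inferred from the warning message):

from vllm_ascend.utils import enable_custom_op

# Nothing CANN-related runs until the first call; the result is then cached.
if enable_custom_op():
    pass  # custom kernels were registered into torch_library
else:
    pass  # all custom ops are disabled; native implementations are used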
@@ -229,9 +230,6 @@ def find_hccl_library() -> str:
     return so_file


-_current_stream = None
-
-
 def current_stream() -> torch.npu.Stream:
     """
     replace `torch.npu.current_stream()` with `vllm.utils.current_stream()`.
@@ -241,12 +239,12 @@ def current_stream() -> torch.npu.Stream:
     directly, so that we can avoid calling `torch.npu.current_stream()`.

     """
-    global _current_stream
-    if _current_stream is None:
+    global _CURRENT_STREAM
+    if _CURRENT_STREAM is None:
         # when this function is called before any stream is set,
         # we return the default stream.
-        _current_stream = torch.npu.current_stream()
-    return _current_stream
+        _CURRENT_STREAM = torch.npu.current_stream()
+    return _CURRENT_STREAM


 def adapt_patch(is_global_patch: bool = False):
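Per the docstring, call sites read the cached stream instead of querying the runtime on every call. A small sketch of the resulting behavior (assuming an NPU runtime is available):

from vllm_ascend.utils import current_stream

s1 = current_stream()  # first call queries torch.npu.current_stream() and caches it
s2 = current_stream()  # later calls return the cached stream object
assert s1 is s2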
@@ -326,6 +324,7 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
                 len(original_sizes))


+# TODO(wxy): Move to ops module
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
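dispose_tensor swaps a tensor's storage for an empty buffer in place, so the old buffer can be reclaimed immediately instead of waiting for every Python reference to disappear. A sketch (set_ is a plain torch op, so this also works on CPU):

import torch

from vllm_ascend.utils import dispose_tensor

x = torch.ones(1024, 1024)
dispose_tensor(x)
print(x.numel())  # 0 -- the original buffer has been released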
@@ -378,10 +377,12 @@ class ProfileExecuteDuration:
         return durations


+# TODO(wxy): Move to ops module
 def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
     return _npu_stream_switch(tag, priority) if enabled else nullcontext()


+# TODO(wxy): Move to ops module
 def npu_wait_tensor(self: torch.Tensor,
                     dependency: torch.Tensor,
                     *,
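Beyond the utils.py changes, the commit message also adds unit tests for vllm_ascend/utils.py and a shared base class in tests/ut/base.py. A hypothetical sketch of that layout (only the file paths come from the commit message; the class and test names are assumptions):

# tests/ut/base.py
import unittest


class TestBase(unittest.TestCase):
    """Shared base class for all unit tests under tests/ut/."""


# tests/ut/test_utils.py
from tests.ut.base import TestBase
from vllm_ascend.utils import _round_up


class TestUtils(TestBase):

    def test_round_up(self):
        self.assertEqual(_round_up(15, 16), 16)
        self.assertEqual(_round_up(17, 16), 32)
        self.assertEqual(_round_up(33, 16), 48)
        self.assertEqual(_round_up(10, 0), -1)  # align == 0 is invalid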