### What this PR does / why we need it? Add initial experimental support for Ascend 310P. This patch squashes the PRs below into one to help validation: - https://github.com/vllm-project/vllm-ascend/pull/914 - https://github.com/vllm-project/vllm-ascend/pull/1318 - https://github.com/vllm-project/vllm-ascend/pull/1327 ### Does this PR introduce _any_ user-facing change? Users can run vLLM on the Atlas 300I DUO series. ### How was this patch tested? CI passed with: - E2E image build for 310P - CI test on A2 with e2e test and longterm test - Unit tests are missing because a real 310P image is needed to run them; they will be added in a separate PR later. - Manual e2e test: - Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B: https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322 - Pangu MGoE 72B The patch has been tested locally on Ascend 310P hardware to ensure that the changes do not break existing functionality and that the new features work as intended. #### ENV information CANN, NNAL version: 8.1.RC1 > [!IMPORTANT] > PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ format and calling NNAL operators on 310P #### Code example ##### Build vllm-ascend from source code ```shell # download source code as vllm-ascend cd vllm-ascend export SOC_VERSION=Ascend310P3 pip install -v -e . cd .. ``` ##### Run offline inference ```python from vllm import LLM, SamplingParams prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。", "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10) # Create an LLM. 
llm = LLM( model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096, max_num_seqs=4, dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P disable_custom_all_reduce=True, trust_remote_code=True, tensor_parallel_size=2, compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]}, ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` --------- Signed-off-by: Vincent Yuan <farawayboat@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: Vincent Yuan <farawayboat@gmail.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: shen-shanshan <467638484@qq.com>
408 lines
13 KiB
Python
408 lines
13 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/vllm/worker/worker.py
|
|
#
|
|
|
|
import atexit
|
|
import math
|
|
from contextlib import contextmanager, nullcontext
|
|
from enum import Enum
|
|
from functools import lru_cache
|
|
from threading import Lock
|
|
from typing import TYPE_CHECKING, List, Tuple
|
|
|
|
import torch
|
|
import torch_npu # noqa: F401 # noqa: F401
|
|
import torchair # type: ignore[import] # noqa: F401
|
|
from packaging.version import InvalidVersion, Version
|
|
from torch_npu.npu.streams import Event
|
|
from vllm.logger import logger
|
|
|
|
import vllm_ascend.envs as envs
|
|
|
|
try:
|
|
# Recent release of torchair has moved these ops to `.scope`.
|
|
from torchair.scope import npu_stream_switch as _npu_stream_switch
|
|
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
|
|
except ImportError:
|
|
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
|
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.config import VllmConfig
|
|
else:
|
|
VllmConfig = None
|
|
|
|
# Upper bound on the number of graphs ACL Graph can capture.
# NOTE: the bound comes from the stream budget: 2048 streams exist and
# 128 of them are held back as a buffer, leaving 1920 for capture.
MAX_CAPTURE_SIZE = 1920

# Name of the Ascend quantization method. The identifier keeps its historical
# spelling ("QUATIZATION") because other modules may import it by this name.
ASCEND_QUATIZATION_METHOD = "ascend"

# Tri-state cache used by enable_custom_op():
# None = not attempted yet, True/False = outcome of the first attempt.
CUSTOM_OP_ENABLED = None

# SOC names that is_310p() treats as the 310P inference series.
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]

# Numeric tensor-format codes (presumably the ACL format enum values for the
# plain ND layout and the fractal NZ layout -- TODO confirm against CANN docs).
ACL_FORMAT_FRACTAL_ND = 2
ACL_FORMAT_FRACTAL_NZ = 29
|
|
|
|
|
|
@lru_cache(maxsize=None)
def _get_soc_version():
    """Return the device name of NPU 0, or "" when it cannot be determined.

    The result is memoized by ``lru_cache``, so the device is queried at most
    once per process.

    Returns:
        The string reported by ``torch.npu.get_device_name(0)``; an empty
        string when no NPU is available, the device count is not positive,
        or the name lookup raises.
    """
    # device_count() is only consulted when is_available() is true.
    has_device = torch.npu.is_available() and torch.npu.device_count() > 0
    if not has_device:
        return ""

    try:
        soc_name = torch.npu.get_device_name(0)
    except Exception:
        # Best effort: any failure while querying is reported as "unknown".
        soc_name = ""
    return soc_name
|
|
|
|
|
|
# Resolved once at import time; is_310p() only compares against this value.
_SOC_VERSION = _get_soc_version()


def is_310p():
    """Return True when the detected SOC is in SOC_VERSION_INFERENCE_SERIES."""
    is_inference_soc = _SOC_VERSION in SOC_VERSION_INFERENCE_SERIES
    return is_inference_soc
|
|
|
|
|
|
class NullHandle:
    """Placeholder handle whose ``wait()`` does nothing.

    Lets callers invoke ``handle.wait()`` unconditionally even when there is
    no real work to wait for.
    """

    def __init__(self):
        """Nothing to initialise."""

    def wait(self):
        """Return immediately; there is nothing to wait for."""
        return None
|
|
|
|
|
|
def _round_up(x: int, align: int):
|
|
if align == 0:
|
|
return -1
|
|
return (x + align - 1) // align * align
|
|
|
|
|
|
def _custom_pad(x, pad_dims):
|
|
return torch.nn.functional.pad(x, pad_dims)
|
|
|
|
|
|
def _custom_reshape(x, target_shape):
|
|
return x.reshape(target_shape)
|
|
|
|
|
|
def _custom_transpose(x, dim1, dim2):
|
|
return x.transpose(dim1, dim2)
|
|
|
|
|
|
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
    """Re-tile a 2-D tensor into 16-column blocks.

    Both dimensions are zero-padded up to a multiple of 16, the result is
    viewed as ``[1, rows_pad, cols_pad // 16, 16]`` and axes 1 and 2 are
    swapped, giving a contiguous tensor of shape
    ``[1, cols_pad // 16, rows_pad, 16]``.

    Args:
        in_tensor: Input indexed on dims 0 and 1 (assumed to be 2-D; the
            reshape below would not match otherwise).

    Returns:
        The padded, re-tiled, contiguous tensor.
    """
    rows, cols = in_tensor.size(0), in_tensor.size(1)
    rows_padded = _round_up(rows, 16)
    cols_padded = _round_up(cols, 16)

    # pad spec order is (last-dim left, last-dim right, dim0 left, dim0 right);
    # only the trailing side of each dimension is padded.
    pad_dims = [0, cols_padded - cols, 0, rows_padded - rows]
    aux_dims = [1, rows_padded, cols_padded // 16, 16]

    padded = _custom_pad(in_tensor, pad_dims)
    blocked = _custom_reshape(padded, aux_dims)
    return _custom_transpose(blocked, 1, 2).contiguous()
|
|
|
|
|
|
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
    """Re-tile a 2-D mask into 16-column blocks with zero padding.

    The mask is copied into a zero tensor whose two dimensions are rounded up
    to multiples of 16, then viewed as ``[1, rows_pad, cols_pad // 16, 16]``
    and permuted to ``[1, cols_pad // 16, rows_pad, 16]``.

    Args:
        mask_tensor: Mask indexed as ``[num_tokens, max_seq_len]``.

    Returns:
        The permuted view (not made contiguous), with the same dtype and
        device as ``mask_tensor``.
    """
    num_tokens, max_seq_len = mask_tensor.shape[0], mask_tensor.shape[1]

    # Round both extents up to the next multiple of 16.
    tokens_pad = (num_tokens + 15) // 16 * 16
    max_seq_len_pad = (max_seq_len + 15) // 16 * 16

    padded = torch.zeros((1, tokens_pad, max_seq_len_pad),
                         dtype=mask_tensor.dtype,
                         device=mask_tensor.device)
    padded[0, :num_tokens, :max_seq_len] = mask_tensor

    blocked_shape = (1, tokens_pad, max_seq_len_pad // 16, 16)
    return padded.reshape(blocked_shape).permute(0, 2, 1, 3)
|
|
|
|
|
|
def aligned_16(tensor: torch.Tensor):
    """Zero-pad dimension 0 of ``tensor`` up to a multiple of 16 (for 310P).

    Args:
        tensor: Tensor with at least one dimension.

    Returns:
        ``tensor`` itself when its first dimension is already a multiple of
        16; otherwise a new tensor of the same dtype/device whose leading
        rows are a copy of ``tensor`` and whose remaining rows are zero.
    """
    n = tensor.size(0)
    remainder = n % 16
    if remainder == 0:
        # Already aligned: hand back the very same object, no copy.
        return tensor

    n_aligned = n + (16 - remainder)
    padded_shape = (n_aligned, *tensor.shape[1:])
    new_tensor = torch.zeros(padded_shape,
                             dtype=tensor.dtype,
                             device=tensor.device)
    # Original rows first; the zero rows trail behind them.
    new_tensor[:n] = tensor
    return new_tensor
|
|
|
|
|
|
def try_register_lib(lib_name: str, lib_info: str = ""):
    """Import an optional library if it is installed, never raising.

    Args:
        lib_name: Importable module name to look up and import.
        lib_info: Optional message logged at INFO level after a successful
            import.

    Returns:
        None. A missing library is silently skipped; a library that is
        present but fails to import is skipped as well, with the failure
        recorded at DEBUG level instead of being discarded.
    """
    import importlib
    import importlib.util

    try:
        if importlib.util.find_spec(lib_name) is None:
            return
        importlib.import_module(lib_name)
    except Exception:
        # Registration stays best-effort, but keep a trace so a broken
        # optional dependency can still be diagnosed.
        logger.debug("Failed to register optional library %s",
                     lib_name,
                     exc_info=True)
        return

    if lib_info:
        logger.info(lib_info)
|
|
|
|
|
|
def enable_custom_op():
    """Lazily import ``vllm_ascend.vllm_ascend_C`` and report whether it loaded.

    Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's
    RTS component, so that ASCEND_RT_VISIBLE_DEVICES can still be modified
    before torch.npu.set_device().

    The outcome of the first attempt is stored in the module-level
    ``CUSTOM_OP_ENABLED`` and returned unchanged by later calls.

    Returns:
        True if the extension module imported, False if it raised ImportError.
    """
    global CUSTOM_OP_ENABLED

    if CUSTOM_OP_ENABLED is None:
        try:
            # Importing the module registers the custom ops into torch_library.
            import vllm_ascend.vllm_ascend_C  # type: ignore  # noqa: F401
        except ImportError:
            CUSTOM_OP_ENABLED = False
            logger.warning(
                "Warning: Failed to register custom ops, all custom ops will be disabled"
            )
        else:
            CUSTOM_OP_ENABLED = True

    return CUSTOM_OP_ENABLED
|
|
|
|
|
|
def find_hccl_library() -> str:
    """Return the HCCL shared-library path or name to load.

    ``envs.HCCL_SO_PATH`` takes precedence when it is set. Otherwise the bare
    name ``libhccl.so`` is returned, provided ``torch.version.cann`` is not
    None (the original note says ``ctypes`` can locate it once ``torch`` has
    been imported).

    Returns:
        The configured path, or ``"libhccl.so"``.

    Raises:
        ValueError: If no path is configured and ``torch.version.cann`` is
            None.
    """
    so_file = envs.HCCL_SO_PATH

    # An explicit override wins.
    if so_file:
        logger.info("Found hccl from environment variable HCCL_SO_PATH=%s",
                    so_file)
        return so_file

    if torch.version.cann is None:
        raise ValueError("HCCL only supports Ascend NPU backends.")

    so_file = "libhccl.so"
    logger.info("Found hccl from library %s", so_file)
    return so_file
|
|
|
|
|
|
# Cached stream object; populated by the first current_stream() call.
_current_stream = None


def current_stream() -> torch.npu.Stream:
    """Return the cached NPU stream, querying torch only on the first call.

    ``torch.npu.current_stream()`` is described as expensive because it
    constructs a new stream object on every call, so the first result is
    stored in the module-level ``_current_stream`` and reused afterwards.

    NOTE(review): the original comment says ``torch.npu.set_stream`` is
    patched to keep this cache up to date; that patch is not part of this
    function -- confirm where ``_current_stream`` gets refreshed.

    Returns:
        The cached stream, or the default current stream when nothing has
        been cached yet.
    """
    global _current_stream

    if _current_stream is not None:
        return _current_stream

    # Called before any stream was recorded: fall back to what torch reports.
    _current_stream = torch.npu.current_stream()
    return _current_stream
|
|
|
|
|
|
def adapt_patch(is_global_patch: bool = False):
    """Import one of the ``vllm_ascend.patch`` submodules for its side effects.

    Args:
        is_global_patch: When True import ``vllm_ascend.patch.platform``;
            otherwise import ``vllm_ascend.patch.worker``.
    """
    if not is_global_patch:
        from vllm_ascend.patch import worker  # noqa: F401
        return

    from vllm_ascend.patch import platform  # noqa: F401
|
|
|
|
|
|
def vllm_version_is(target_vllm_version: str):
    """Check whether the vLLM version in use equals ``target_vllm_version``.

    The version is taken from ``envs.VLLM_VERSION`` when that is not None,
    otherwise from ``vllm.__version__``.

    Args:
        target_vllm_version: Version string to compare against.

    Returns:
        True when both strings parse to equal ``packaging`` versions.

    Raises:
        ValueError: If version parsing raises ``InvalidVersion``.
    """
    vllm_version = envs.VLLM_VERSION
    if vllm_version is None:
        # No override configured: ask the installed package.
        import vllm
        vllm_version = vllm.__version__

    try:
        installed = Version(vllm_version)
        wanted = Version(target_vllm_version)
        return installed == wanted
    except InvalidVersion:
        raise ValueError(
            f"Invalid vllm version {vllm_version} found. A dev version of vllm "
            "is installed probably. Set the environment variable VLLM_VERSION "
            "to control it by hand. And please make sure the value follows the "
            "format of x.y.z.")
|
|
|
|
|
|
def _uniform_sample_indices(total: int, count: int) -> List[int]:
    """Pick ``count`` evenly spaced indices out of ``range(total)``.

    The first and last index are always included when ``count >= 2``. With
    ``count == 1`` only the last index is returned; with ``count <= 0`` the
    result is empty.
    """
    if count <= 0:
        return []
    if count == 1:
        # No spacing to compute (dividing by count - 1 would be a division
        # by zero); keep the final entry, which is what the
        # "preserve first and last" assignment would leave for one slot.
        return [total - 1]

    step = (total - 1) / (count - 1)
    indices = [round(i * step) for i in range(count)]
    # Rounding must never drop the two endpoints.
    indices[0], indices[-1] = 0, total - 1
    return indices


def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    """Update ACL graph capture sizes based on hardware limitations.

    The number of capture sizes that fit is
    ``floor(MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)``,
    where ``parallel_factor`` is 1 plus the number of parallel dimensions
    (local DP, TP, EP, expert-TP) whose size exceeds 1. When more sizes are
    configured than fit, an evenly spaced subset is installed through
    ``compilation_config.init_with_cudagraph_sizes``; otherwise the configured
    sizes are kept as they are.

    Args:
        vllm_config: Config whose ``compilation_config.cudagraph_capture_sizes``
            is read and updated in place.
    """
    compilation_config = vllm_config.compilation_config
    # Stash the configured sizes and clear the field; it is either rebuilt by
    # init_with_cudagraph_sizes() or restored unchanged below.
    original_sizes, compilation_config.cudagraph_capture_sizes = \
        compilation_config.cudagraph_capture_sizes, None

    num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
    parallel_config = vllm_config.parallel_config

    # TODO: Find out whether we need to take into account the pp_size
    parallel_factor = 1 + sum(size > 1 for size in [
        parallel_config.data_parallel_size_local,
        parallel_config.tensor_parallel_size,
        parallel_config.expert_parallel_size,
        parallel_config.expert_tensor_parallel_size,
    ])

    # Calculate maximum supported batch sizes considering model architecture
    max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
                                     (num_hidden_layers + 1) / parallel_factor)
    logger.info("Calculated maximum supported batch sizes for ACL graph: %s",
                max_num_batch_sizes)

    if max_num_batch_sizes >= len(original_sizes):
        # Everything fits: put the configured sizes back untouched.
        compilation_config.cudagraph_capture_sizes = original_sizes
        logger.info(
            "No adjustment needed for ACL graph batch sizes: %s model (layers: %d) with %d sizes",
            vllm_config.model_config.architectures[0], num_hidden_layers,
            len(original_sizes))
        return

    # Too many sizes: keep a representative, evenly spaced subset. Budgets of
    # 0 or 1 (very deep models and/or several parallel dimensions) are handled
    # by the helper instead of failing on a zero division or an empty list.
    indices = _uniform_sample_indices(len(original_sizes),
                                      max_num_batch_sizes)
    sampled_sizes = [original_sizes[i] for i in indices]
    compilation_config.init_with_cudagraph_sizes(sampled_sizes)

    logger.info(
        "Adjusted ACL graph batch sizes for %s model (layers: %d): %d → %d sizes",
        vllm_config.model_config.architectures[0],
        num_hidden_layers,
        len(original_sizes),
        len(compilation_config.
            cudagraph_capture_sizes  # type: ignore[arg-type]
            ))
|
|
|
|
|
|
def dispose_tensor(x: torch.Tensor):
    """Point ``x`` at an empty tensor in place, dropping its current data.

    ``x`` keeps its device and dtype but afterwards has zero elements.
    """
    placeholder = torch.empty((0, ), device=x.device, dtype=x.dtype)
    x.set_(placeholder)
|
|
|
|
|
|
class ProfileExecuteDuration:
    """Singleton that times tagged code regions with NPU events.

    ``capture_async`` records a start/end event pair around a ``with`` body
    without synchronizing; ``pop_captured_sync`` later synchronizes on the
    recorded events and returns the elapsed time per tag. Both are no-ops
    unless ``envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE`` is truthy.
    """
    # The single shared instance, created on first construction.
    _instance = None
    # (tag, start event, end event) triples waiting for pop_captured_sync().
    _observations: List[Tuple[str, Event, Event]] = []
    # Guards both instance creation and every access to _observations.
    _lock = Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                # Drop any pending events at interpreter exit.
                atexit.register(cls._instance.destroy)
            return cls._instance

    def destroy(self):
        """Discard all pending observations."""
        with self._lock:
            self._observations.clear()

    @contextmanager
    def capture_async(self, duration_tag: str):
        """Record timing events around the ``with`` body under ``duration_tag``.

        The end event is recorded and the observation stored even when the
        body raises. Nothing is recorded when observation is disabled.
        """
        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
            yield
            return

        observe_start = Event(enable_timing=True)
        observe_start.record()
        try:
            yield
        finally:
            observe_end = Event(enable_timing=True)
            observe_end.record()
            with self._lock:
                self._observations.append(
                    (duration_tag, observe_start, observe_end))

    def pop_captured_sync(self) -> dict:
        """Pop and synchronize all events in the observation list.

        Returns:
            Mapping of tag to the value of ``start.elapsed_time(end)``
            (milliseconds in the torch event API -- confirm for torch_npu).
            Empty when observation is disabled. If a tag was captured several
            times, the earliest capture wins.
        """
        durations: dict[str, float] = {}
        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
            return durations

        # Drain atomically: the emptiness check and the removal happen under
        # the same lock acquisition, so a concurrent pop/destroy cannot empty
        # the list in between. The list is cleared in place because it is a
        # class attribute shared through the singleton.
        with self._lock:
            pending = list(self._observations)
            self._observations.clear()

        # Synchronize outside the lock so capture_async() is not blocked
        # while waiting on the device. Newest-first matches the previous
        # pop() order.
        for tag, observe_start, observe_end in reversed(pending):
            observe_end.synchronize()
            durations[tag] = observe_start.elapsed_time(observe_end)

        return durations
|
|
|
|
|
|
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    """Return a stream-switch context manager, or a no-op one when disabled.

    Args:
        tag: Passed through to torchair's stream-switch op.
        priority: Passed through to torchair's stream-switch op.
        enabled: When False a ``nullcontext()`` is returned instead.
    """
    if not enabled:
        return nullcontext()
    return _npu_stream_switch(tag, priority)
|
|
|
|
|
|
def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    """Apply torchair's ``npu_wait_tensor`` to ``self``, or pass it through.

    Args:
        self: Tensor handed to the wait op (a plain first argument; this is a
            module-level function, not a method).
        dependency: Tensor handed to the wait op as the dependency.
        enabled: When False ``self`` is returned untouched.

    Returns:
        The wait op's result, or ``self`` itself when disabled.
    """
    if not enabled:
        return self
    return _npu_wait_tensor(self, dependency)
|
|
|
|
|
|
# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
    """Fused-MoE strategies selectable by ``get_fused_moe_state``."""

    AllGather = 0
    All2All = 1
    MC2 = 2
|
|
|
|
|
|
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool):
    """Choose the fused-MoE strategy for the given expert-parallel setup.

    Args:
        ep_size: Expert-parallel size.
        with_prefill: Whether prefill is involved.

    Returns:
        ``FusedMoEState.AllGather`` when ``ep_size`` is 1;
        ``FusedMoEState.All2All`` when ``ep_size`` is below 16 or
        ``with_prefill`` is true; ``FusedMoEState.MC2`` otherwise.
    """
    if ep_size == 1:
        return FusedMoEState.AllGather

    # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
    if with_prefill or ep_size < 16:
        return FusedMoEState.All2All

    return FusedMoEState.MC2
|