[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it? Add initial experimental support for Ascend 310P, this patch squash below PR into one to help validation: - https://github.com/vllm-project/vllm-ascend/pull/914 - https://github.com/vllm-project/vllm-ascend/pull/1318 - https://github.com/vllm-project/vllm-ascend/pull/1327 ### Does this PR introduce _any_ user-facing change? User can run vLLM on Altlas 300I DUO series ### How was this patch tested? CI passed with: - E2E image build for 310P - CI test on A2 with e2e test and longterm test - Unit test missing because need a real 310P image to have the test, will add in a separate PR later. - Manually e2e test: - Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B: https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322 - Pangu MGoE 72B The patch has been tested locally on Ascend 310P hardware to ensure that the changes do not break existing functionality and that the new features work as intended. #### ENV information CANN, NNAL version: 8.1.RC1 > [!IMPORTANT] > PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ format and calling NNAL operators on 310P #### Code example ##### Build vllm-ascend from source code ```shell # download source code as vllm-ascend cd vllm-ascend export SOC_VERSION=Ascend310P3 pip install -v -e . cd .. ``` ##### Run offline inference ```python from vllm import LLM, SamplingParams prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。", "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10) # Create an LLM. llm = LLM( model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096, max_num_seqs=4, dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P disable_custom_all_reduce=True, trust_remote_code=True, tensor_parallel_size=2, compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]}, ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` --------- Signed-off-by: Vincent Yuan <farawayboat@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: Vincent Yuan <farawayboat@gmail.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -21,11 +21,12 @@ import atexit
|
||||
import math
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
import torch_npu # noqa: F401 # noqa: F401
|
||||
import torchair # type: ignore[import] # noqa: F401
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from torch_npu.npu.streams import Event
|
||||
@@ -57,6 +58,116 @@ ASCEND_QUATIZATION_METHOD = "ascend"
|
||||
|
||||
CUSTOM_OP_ENABLED = None
|
||||
|
||||
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
||||
|
||||
ACL_FORMAT_FRACTAL_ND = 2
|
||||
ACL_FORMAT_FRACTAL_NZ = 29
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _get_soc_version():
|
||||
"""Gets the SOC version and caches it."""
|
||||
if not torch.npu.is_available():
|
||||
return ""
|
||||
device_count = torch.npu.device_count()
|
||||
if device_count <= 0:
|
||||
return ""
|
||||
try:
|
||||
return torch.npu.get_device_name(0)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
_SOC_VERSION = _get_soc_version()
|
||||
|
||||
|
||||
def is_310p():
|
||||
return _SOC_VERSION in SOC_VERSION_INFERENCE_SERIES
|
||||
|
||||
|
||||
class NullHandle:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def wait(self):
|
||||
pass
|
||||
|
||||
|
||||
def _round_up(x: int, align: int):
|
||||
if align == 0:
|
||||
return -1
|
||||
return (x + align - 1) // align * align
|
||||
|
||||
|
||||
def _custom_pad(x, pad_dims):
|
||||
return torch.nn.functional.pad(x, pad_dims)
|
||||
|
||||
|
||||
def _custom_reshape(x, target_shape):
|
||||
return x.reshape(target_shape)
|
||||
|
||||
|
||||
def _custom_transpose(x, dim1, dim2):
|
||||
return x.transpose(dim1, dim2)
|
||||
|
||||
|
||||
def nd_to_nz_2d(in_tensor: torch.Tensor) -> torch.Tensor:
|
||||
aux_dims = [0, 0, 0, 0]
|
||||
aux_dims[0] = 1
|
||||
aux_dims[1] = _round_up(in_tensor.size(0), 16)
|
||||
|
||||
pad_dims = [0, 0, 0, 0]
|
||||
pad_dims[3] = _round_up(in_tensor.size(0), 16) - in_tensor.size(0)
|
||||
|
||||
aux_dims[2] = _round_up(in_tensor.size(1), 16) // 16
|
||||
aux_dims[3] = 16
|
||||
pad_dims[1] = _round_up(in_tensor.size(1), 16) - in_tensor.size(1)
|
||||
|
||||
return _custom_transpose(
|
||||
_custom_reshape(_custom_pad(in_tensor, pad_dims), aux_dims), 1,
|
||||
2).contiguous()
|
||||
|
||||
|
||||
def nd_to_nz_spec(mask_tensor: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = mask_tensor.shape[0]
|
||||
max_seq_len = mask_tensor.shape[1]
|
||||
|
||||
tokens_pad = (num_tokens + 15) // 16 * 16
|
||||
max_seq_len_pad = (max_seq_len + 15) // 16 * 16
|
||||
|
||||
mask_tensor_pad = \
|
||||
torch.zeros((1, tokens_pad, max_seq_len_pad), dtype=mask_tensor.dtype, device=mask_tensor.device)
|
||||
mask_tensor_pad[0][:num_tokens, :max_seq_len] = mask_tensor
|
||||
mask = mask_tensor_pad.reshape(
|
||||
(1, tokens_pad, max_seq_len_pad // 16, 16)).permute(0, 2, 1, 3)
|
||||
return mask
|
||||
|
||||
|
||||
def aligned_16(tensor: torch.Tensor):
|
||||
"""Aligned tensor for 310P"""
|
||||
|
||||
# Get the size of the current 0th dimension
|
||||
n = tensor.size(0)
|
||||
|
||||
# Calculate the aligned size
|
||||
n_aligned = ((n + 15) // 16) * 16
|
||||
|
||||
# If already aligned, return the original tensor
|
||||
if n == n_aligned:
|
||||
return tensor
|
||||
|
||||
# Create a new tensor with shape (n_aligned, H, W) and fill it with zeros
|
||||
new_tensor = torch.zeros(n_aligned,
|
||||
*tensor.shape[1:],
|
||||
dtype=tensor.dtype,
|
||||
device=tensor.device)
|
||||
|
||||
# Copy the original tensor to the first N positions of the new tensor
|
||||
new_tensor[:n] = tensor
|
||||
|
||||
return new_tensor
|
||||
|
||||
|
||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
import importlib
|
||||
|
||||
Reference in New Issue
Block a user