[refact] unified soc_version code (#4359)
### What this PR does / why we need it?
Currently, there are two paths to judge the chip type in code,
`get_ascend_soc_version` use `get_soc_version` api in torch_npu, and
`is_310p` `use _build_info.__soc_version__`, which generate when
install. We need to unify the two paths.
We need to unify these codes based on the following points:
1. We need to ensure consistency in chip type judgment between compiling
and running states;
2. In compiling state, we need chip type to complete op's compilation,
but in running state, we only need device
type(910B/910_93/310P/910_95/etc) to make code branch judgement;
3. In compiling state, torch_npu may not have been installed yet, so we
can't use torch_npu's api.
Based on the above points, we have made the following changes:
1. When user set env `SOC_VERSION`, use it; when not set, query
soc_version by `npu-smi`;
2. generate device_type based on soc_version when compiling, and write
`__device_type__` instead of `__soc_version__` in `_build_info.py`;
3. In running state, use `__device_type__` to judge code branch.
### Does this PR introduce _any_ user-facing change?
When not set env `SOC_VERSION`, it will not be `ASCEND910B1` by default,
we will query soc_version by `npu-smi`. And env `SOC_VERSION` must be in
the list `soc_to_device` in `setup.py`.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -42,9 +42,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec,
|
||||
prefill_context_parallel_enable,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
aligned_16, get_ascend_device_type, nd_to_nz_2d,
|
||||
nd_to_nz_spec, prefill_context_parallel_enable,
|
||||
weak_ref_tensors)
|
||||
|
||||
# isort: off
|
||||
@@ -83,7 +83,7 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
return (2, num_blocks, num_kv_heads * head_size // 16, block_size,
|
||||
16)
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
@@ -351,7 +351,7 @@ class AscendAttentionMetadataBuilder:
|
||||
query_start_loc = query_start_loc_cpu.to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
if attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
@@ -702,7 +702,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
mask = attn_metadata.attn_mask
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
@@ -783,7 +783,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
attn_metadata: AscendMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# seq_lens_tensor needs to be transferred to the device for 310P.
|
||||
attn_metadata.seq_lens = \
|
||||
attn_metadata.seq_lens.to(device=query.device)
|
||||
@@ -857,7 +857,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# Do reformat in case of broadcasted tensors.
|
||||
attn_metadata.attn_mask = \
|
||||
torch_npu.npu_format_cast(attn_metadata.attn_mask.contiguous(),
|
||||
|
||||
@@ -32,7 +32,7 @@ from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
prefill_context_parallel_enable)
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
@@ -376,7 +376,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
self.finished_reqs: set[str] = set()
|
||||
self.soc_info = get_ascend_soc_version()
|
||||
self.soc_info = get_ascend_device_type()
|
||||
# Set hccl deterministic for model execute
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
@@ -761,7 +761,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
decode_server_device_info)
|
||||
|
||||
if self.soc_info == AscendSocVersion.A3:
|
||||
if self.soc_info == AscendDeviceType._910_93:
|
||||
# generate super_pod_list for rank table
|
||||
super_pod_list = []
|
||||
prefill_super_pod_info = {
|
||||
|
||||
@@ -50,11 +50,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# value is None, which means the system default C compiler will be used.
|
||||
"C_COMPILER":
|
||||
lambda: os.getenv("C_COMPILER", None),
|
||||
# The version of the Ascend chip. If not set, the default value is
|
||||
# ASCEND910B1(Available for A2 and A3 series). It's used for package building.
|
||||
# The version of the Ascend chip. It's used for package building.
|
||||
# If not set, we will query chip info through `npu-smi`.
|
||||
# Please make sure that the version is correct.
|
||||
"SOC_VERSION":
|
||||
lambda: os.getenv("SOC_VERSION", "ASCEND910B1"),
|
||||
lambda: os.getenv("SOC_VERSION", None),
|
||||
# If set, vllm-ascend will print verbose logs during compilation
|
||||
"VERBOSE":
|
||||
lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
||||
|
||||
@@ -4,9 +4,9 @@ from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
|
||||
@@ -33,10 +33,10 @@ class AscendSiluAndMul(SiluAndMul):
|
||||
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
|
||||
@@ -43,9 +43,9 @@ from vllm_ascend.quantization.w4a8_dynamic import \
|
||||
AscendW4A8DynamicFusedMoEMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
|
||||
is_enable_nz, npu_stream_switch,
|
||||
shared_expert_dp_enabled,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
enable_sp, get_ascend_device_type, is_enable_nz,
|
||||
npu_stream_switch, shared_expert_dp_enabled,
|
||||
shared_experts_calculation_stream)
|
||||
|
||||
|
||||
@@ -79,7 +79,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if not is_310p() and is_enable_nz():
|
||||
if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
|
||||
):
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
@@ -22,7 +22,8 @@ from torch.nn.functional import pad
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def cumsum_group_list(group_list: torch.Tensor,
|
||||
@@ -210,7 +211,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
@@ -98,11 +98,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (
|
||||
get_ascend_soc_version() == AscendSocVersion.A3)
|
||||
get_ascend_device_type() == AscendDeviceType._910_93)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
|
||||
@@ -32,9 +32,10 @@ def _addrmsnorm_forward_oot(
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
if layer is not None and not is_310p():
|
||||
if layer is not None and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
try:
|
||||
weight_prefetch_method = get_forward_context(
|
||||
@@ -67,7 +68,7 @@ def _addrmsnorm_forward_oot(
|
||||
)
|
||||
|
||||
else:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
@@ -195,9 +196,9 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
@@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
from vllm.platforms import CpuArchEnum
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
@@ -49,8 +50,9 @@ def _rope_forward_oot(
|
||||
if self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if _custom_rotary_embedding_enabled(query, is_neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
if _custom_rotary_embedding_enabled(
|
||||
query, is_neox_style, self.head_size) and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
import vllm.envs as envs_vllm
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
|
||||
def parallel_config_get_dp_port(self) -> int:
|
||||
@@ -111,5 +111,5 @@ def communication_adaptation_310p():
|
||||
torch.distributed.distributed_c10d.all_reduce)
|
||||
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
communication_adaptation_310p()
|
||||
|
||||
@@ -30,8 +30,9 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
delete_torchair_cache_file)
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
|
||||
is_vl_model, prefill_context_parallel_enable,
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, AscendDeviceType,
|
||||
enable_sp, get_ascend_device_type, is_vl_model,
|
||||
prefill_context_parallel_enable,
|
||||
update_aclgraph_sizes,
|
||||
update_cudagraph_capture_sizes,
|
||||
update_default_aclgraph_sizes)
|
||||
@@ -281,7 +282,7 @@ class NPUPlatform(Platform):
|
||||
cache_config.block_size = origin_block_size
|
||||
|
||||
# Activate custom ops for v1, except on 310P
|
||||
if not is_310p():
|
||||
if get_ascend_device_type() != AscendDeviceType._310P:
|
||||
compilation_config.custom_ops = ["all"]
|
||||
|
||||
# If ascend_scheduler_config is enabled,
|
||||
|
||||
@@ -25,7 +25,8 @@ from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
get_ascend_device_type, is_enable_nz)
|
||||
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor,
|
||||
@@ -45,7 +46,8 @@ class AscendW8A8LinearMethod:
|
||||
|
||||
def __init__(self) -> None:
|
||||
# aclnn quant matmul requires to transpose matrix B, set to true by default.
|
||||
self.transpose_weight = not is_310p()
|
||||
self.transpose_weight = get_ascend_device_type(
|
||||
) != AscendDeviceType._310P
|
||||
|
||||
@staticmethod
|
||||
def get_weight(
|
||||
@@ -147,7 +149,7 @@ class AscendW8A8LinearMethod:
|
||||
)
|
||||
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# On 300I Duo platform, we need transpose again if
|
||||
# using nz. This transpose can be skipped in torchair.
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
@@ -299,7 +301,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
return fused_experts_310p(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
@@ -328,7 +330,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if not is_310p():
|
||||
if get_ascend_device_type() != AscendDeviceType._310P:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
@@ -345,7 +347,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
||||
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
layer.w13_input_scale.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.max())
|
||||
layer.w2_input_scale.data = torch.nn.Parameter(
|
||||
@@ -365,7 +367,8 @@ class AscendW8A8FusedMoEMethod:
|
||||
# converting ACL_FORMAT_FRACTAL_NZ.
|
||||
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
||||
# ACL_FORMAT_FRACTAL_NZ.
|
||||
if not is_310p() and is_enable_nz():
|
||||
if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
|
||||
):
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch_npu
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
|
||||
|
||||
@@ -25,7 +25,8 @@ class AscendTopKTopPSampler(TopKTopPSampler):
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# npu_top_k_top_p uses the operator aclnnApplyTopKTopP, but aclnnApplyTopKTopP currently does not support 310P
|
||||
if not is_310p() and p is not None and k is not None and 1 <= int(
|
||||
if get_ascend_device_type(
|
||||
) != AscendDeviceType._310P and p is not None and k is not None and 1 <= int(
|
||||
k.max()) <= 1024:
|
||||
# npu_top_k_top_p's parameter order is (logits, p, k), not (logits, k, p)
|
||||
return torch_npu.npu_top_k_top_p(logits, p, k)
|
||||
|
||||
@@ -57,7 +57,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
get_ascend_device_type)
|
||||
|
||||
_ROUTER_SCALE = None
|
||||
|
||||
@@ -448,7 +449,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
|
||||
# on 300I Duo platform, we find that num_voted_experts set to 5 achieves
|
||||
# good performance without sacrifice too much accuracy. for other platform,
|
||||
# this is set to 8 to use original pangu grouped topk.
|
||||
num_voted_experts = 5 if is_310p() else 8
|
||||
num_voted_experts = 5 if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P else 8
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
@@ -1109,7 +1111,8 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
if is_310p() and "head" in name:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and "head" in name:
|
||||
# on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than
|
||||
# ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented
|
||||
# by linear, we manually cast the format here.
|
||||
|
||||
@@ -28,9 +28,9 @@ def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
|
||||
@@ -51,8 +51,8 @@ from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state,
|
||||
npu_stream_switch, npu_wait_tensor,
|
||||
super_kernel)
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_ascend_soc_version, is_310p,
|
||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||
get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
@@ -75,11 +75,11 @@ def torchair_fused_experts_with_mc2(
|
||||
ep_world_size = moe_parallel_config.ep_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
@@ -467,7 +467,7 @@ def torchair_fused_experts_moge(
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
|
||||
@@ -57,9 +57,9 @@ def torchair_rmsnorm_forward_oot(
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
@@ -60,8 +61,9 @@ def rope_forward_oot(
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
if custom_rotary_embedding_enabled(
|
||||
query, neox_style, self.head_size) and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
|
||||
@@ -28,8 +28,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
|
||||
super_kernel)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
dispose_tensor, get_ascend_device_type,
|
||||
is_enable_nz,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
@@ -234,11 +234,11 @@ def torchair_fused_experts_with_mc2(
|
||||
ep_world_size = ep_group.world_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
|
||||
@@ -34,8 +34,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
aligned_16, get_ascend_device_type, nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||
@@ -185,7 +185,8 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
@@ -381,7 +382,7 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
|
||||
@@ -42,8 +42,7 @@ from vllm_ascend.torchair.utils import (
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p, get_ascend_soc_version,
|
||||
AscendSocVersion)
|
||||
AscendDeviceType, get_ascend_device_type)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -125,13 +124,13 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
max_num_tokens, tp_size)
|
||||
self.mc2_tokens_capacity = max_graph_batch_size
|
||||
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
|
||||
logger.error(
|
||||
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
|
||||
logger.error(
|
||||
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
@@ -207,7 +206,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
@@ -230,7 +229,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
@@ -371,7 +370,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
@@ -384,7 +383,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
@@ -414,7 +413,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_distributed import \
|
||||
@@ -428,7 +427,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.tiling_schedule_optimize = get_ascend_device_type(
|
||||
) != AscendDeviceType._310P
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
@@ -531,8 +531,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
# NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
# Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
|
||||
# on all EP ranks
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel:
|
||||
self._align_graph_size_divisible_by_tp_size()
|
||||
|
||||
def _align_graph_size_divisible_by_tp_size(self):
|
||||
|
||||
@@ -48,7 +48,6 @@ ACL_FORMAT_FRACTAL_ND = 2
|
||||
ACL_FORMAT_FRACTAL_NZ = 29
|
||||
|
||||
_CUSTOM_OP_ENABLED = None
|
||||
_IS_310P = None
|
||||
_SLEEP_MODE_ENABLED = None
|
||||
_CURRENT_STREAM = None
|
||||
_PREFETCH_STREAM = None
|
||||
@@ -121,14 +120,6 @@ def _unregister_print_streams_on_exit():
|
||||
atexit.register(_unregister_print_streams_on_exit)
|
||||
|
||||
|
||||
def is_310p():
|
||||
global _IS_310P
|
||||
if _IS_310P is None:
|
||||
from vllm_ascend import _build_info # type: ignore
|
||||
_IS_310P = _build_info.__soc_version__.lower().startswith("ascend310p")
|
||||
return _IS_310P
|
||||
|
||||
|
||||
def is_enable_nz():
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
|
||||
|
||||
@@ -703,32 +694,47 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN.
|
||||
# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does.
|
||||
class AscendSocVersion(Enum):
|
||||
A2 = 0
|
||||
A3 = 1
|
||||
UNDEFINED = 2
|
||||
class AscendDeviceType(Enum):
|
||||
_910B = 0 # A2
|
||||
_910_93 = 1 # A3
|
||||
_310P = 2
|
||||
_910_95 = 3 # A5
|
||||
|
||||
|
||||
_ascend_soc_version = None
|
||||
_ascend_device_type = None
|
||||
|
||||
|
||||
def init_ascend_soc_version():
|
||||
def _init_ascend_device_type():
|
||||
global _ascend_device_type
|
||||
from vllm_ascend import _build_info # type: ignore
|
||||
_ascend_device_type = AscendDeviceType[_build_info.__device_type__]
|
||||
|
||||
|
||||
def check_ascend_device_type():
|
||||
global _ascend_device_type
|
||||
if _ascend_device_type is None:
|
||||
_init_ascend_device_type()
|
||||
|
||||
soc_version = torch_npu.npu.get_soc_version()
|
||||
global _ascend_soc_version
|
||||
if 220 <= soc_version <= 225:
|
||||
_ascend_soc_version = AscendSocVersion.A2
|
||||
cur_device_type = AscendDeviceType._910B
|
||||
elif 250 <= soc_version <= 255:
|
||||
_ascend_soc_version = AscendSocVersion.A3
|
||||
cur_device_type = AscendDeviceType._910_93
|
||||
elif 200 <= soc_version <= 205:
|
||||
cur_device_type = AscendDeviceType._310P
|
||||
elif soc_version == 260:
|
||||
cur_device_type = AscendDeviceType._910_95
|
||||
else:
|
||||
_ascend_soc_version = AscendSocVersion.UNDEFINED
|
||||
raise RuntimeError(f"Can not support soc_version: {soc_version}.")
|
||||
|
||||
assert _ascend_device_type == cur_device_type, f"Current device type: {cur_device_type} does not match the installed version's device type: {_ascend_device_type}, please check your installation package."
|
||||
|
||||
|
||||
def get_ascend_soc_version():
|
||||
global _ascend_soc_version
|
||||
assert _ascend_soc_version is not None
|
||||
return _ascend_soc_version
|
||||
def get_ascend_device_type():
|
||||
global _ascend_device_type
|
||||
if _ascend_device_type is None:
|
||||
_init_ascend_device_type()
|
||||
return _ascend_device_type
|
||||
|
||||
|
||||
def lmhead_tp_enable() -> bool:
|
||||
|
||||
@@ -138,9 +138,9 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendSocVersion, ProfileExecuteDuration,
|
||||
enable_sp, get_ascend_soc_version, is_310p,
|
||||
is_enable_nz, is_moe_model, lmhead_tp_enable,
|
||||
AscendDeviceType, ProfileExecuteDuration,
|
||||
enable_sp, get_ascend_device_type, is_enable_nz,
|
||||
is_moe_model, lmhead_tp_enable,
|
||||
prefill_context_parallel_enable)
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@@ -161,7 +161,7 @@ import torch_npu
|
||||
# if true, allow tensor initialization and casting with internal format (e.g., NZ)
|
||||
torch.npu.config.allow_internal_format = True
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
torch_npu.npu.set_compile_mode(jit_compile=False)
|
||||
ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
|
||||
else:
|
||||
@@ -2226,14 +2226,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not is_moe_model(self.vllm_config):
|
||||
return None
|
||||
|
||||
soc_version = get_ascend_soc_version()
|
||||
soc_version = get_ascend_device_type()
|
||||
quant_type = getattr(self.vllm_config.model_config.hf_config,
|
||||
'moe_quantize', None)
|
||||
model_type = self.vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if not self.parallel_config.enable_expert_parallel:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
elif soc_version in {AscendSocVersion.A2}:
|
||||
elif soc_version in {AscendDeviceType._910B}:
|
||||
if (num_tokens <= self.mc2_tokens_capacity
|
||||
and self.parallel_config.world_size_across_dp >= 16):
|
||||
moe_comm_type = MoECommType.MC2
|
||||
@@ -2244,7 +2244,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
elif soc_version in {AscendSocVersion.A3}:
|
||||
elif soc_version in {AscendDeviceType._910_93}:
|
||||
moe_comm_type = (MoECommType.MC2
|
||||
if num_tokens <= self.mc2_tokens_capacity else
|
||||
MoECommType.ALLTOALL)
|
||||
@@ -3183,7 +3183,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
if self.dynamic_eplb:
|
||||
model_register(self.model, self.model_config)
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
||||
@@ -50,7 +50,7 @@ from vllm_ascend.cpu_binding import bind_cpus
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
|
||||
from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz,
|
||||
prefill_context_parallel_enable,
|
||||
register_ascend_customop, sleep_mode_enabled,
|
||||
try_register_lib)
|
||||
@@ -91,7 +91,7 @@ class NPUWorker(WorkerBase):
|
||||
register_ascend_customop(vllm_config)
|
||||
# init ascend config and soc version
|
||||
init_ascend_config(vllm_config)
|
||||
init_ascend_soc_version()
|
||||
check_ascend_device_type()
|
||||
use_sparse = False
|
||||
if vllm_config.model_config is not None:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
|
||||
Reference in New Issue
Block a user