[refact] unified soc_version code (#4359)
### What this PR does / why we need it?
Currently, there are two paths to judge the chip type in code,
`get_ascend_soc_version` use `get_soc_version` api in torch_npu, and
`is_310p` `use _build_info.__soc_version__`, which generate when
install. We need to unify the two paths.
We need to unify these codes based on the following points:
1. We need to ensure consistency in chip type judgment between compiling
and running states;
2. In compiling state, we need chip type to complete op's compilation,
but in running state, we only need device
type(910B/910_93/310P/910_95/etc) to make code branch judgement;
3. In compiling state, torch_npu may not have been installed yet, so we
can't use torch_npu's api.
Based on the above points, we have made the following changes:
1. When user set env `SOC_VERSION`, use it; when not set, query
soc_version by `npu-smi`;
2. generate device_type based on soc_version when compiling, and write
`__device_type__` instead of `__soc_version__` in `_build_info.py`;
3. In running state, use `__device_type__` to judge code branch.
### Does this PR introduce _any_ user-facing change?
When not set env `SOC_VERSION`, it will not be `ASCEND910B1` by default,
we will query soc_version by `npu-smi`. And env `SOC_VERSION` must be in
the list `soc_to_device` in `setup.py`.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -57,7 +57,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.sample.sampler import Sampler
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
get_ascend_device_type)
|
||||
|
||||
_ROUTER_SCALE = None
|
||||
|
||||
@@ -448,7 +449,8 @@ class PanguProMoESparseMoeBlock(nn.Module):
|
||||
# on 300I Duo platform, we find that num_voted_experts set to 5 achieves
|
||||
# good performance without sacrifice too much accuracy. for other platform,
|
||||
# this is set to 8 to use original pangu grouped topk.
|
||||
num_voted_experts = 5 if is_310p() else 8
|
||||
num_voted_experts = 5 if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P else 8
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.num_experts,
|
||||
@@ -1109,7 +1111,8 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
if is_310p() and "head" in name:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and "head" in name:
|
||||
# on 300I Duo platform, ACL_FORMAT_FRACTAL_NZ is much more preferred than
|
||||
# ACL_FORMAT_FRACTAL_ND by matmul operation. Since lmhead is also implemented
|
||||
# by linear, we manually cast the format here.
|
||||
|
||||
@@ -28,9 +28,9 @@ def torchair_silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
|
||||
else:
|
||||
out = torch_npu.npu_swiglu(x)
|
||||
|
||||
@@ -51,8 +51,8 @@ from vllm_ascend.torchair.utils import (get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state,
|
||||
npu_stream_switch, npu_wait_tensor,
|
||||
super_kernel)
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_ascend_soc_version, is_310p,
|
||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||
get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
@@ -75,11 +75,11 @@ def torchair_fused_experts_with_mc2(
|
||||
ep_world_size = moe_parallel_config.ep_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
@@ -467,7 +467,7 @@ def torchair_fused_experts_moge(
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
|
||||
@@ -57,9 +57,9 @@ def torchair_rmsnorm_forward_oot(
|
||||
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
if residual is not None:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
orig_dtype = residual.dtype
|
||||
x = x + residual.to(x.dtype)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import enable_custom_op, is_310p
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
|
||||
get_ascend_device_type)
|
||||
|
||||
|
||||
def custom_rotary_embedding_enabled(query, neox_style, head_size):
|
||||
@@ -60,8 +61,9 @@ def rope_forward_oot(
|
||||
if is_neox_style_override is not None:
|
||||
neox_style = is_neox_style_override
|
||||
# adopt custom kernel path for rotary_embedding
|
||||
if custom_rotary_embedding_enabled(query, neox_style,
|
||||
self.head_size) and not is_310p():
|
||||
if custom_rotary_embedding_enabled(
|
||||
query, neox_style, self.head_size) and get_ascend_device_type(
|
||||
) != AscendDeviceType._310P:
|
||||
query, key = torch.ops._C_ascend.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
|
||||
@@ -28,8 +28,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.utils import (npu_stream_switch, npu_wait_tensor,
|
||||
super_kernel)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version,
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
dispose_tensor, get_ascend_device_type,
|
||||
is_enable_nz,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
@@ -234,11 +234,11 @@ def torchair_fused_experts_with_mc2(
|
||||
ep_world_size = ep_group.world_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
need_extra_args = (get_ascend_device_type() == AscendDeviceType._910_93
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
a3_need_extra_args = get_ascend_device_type() == AscendDeviceType._910_93
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
|
||||
@@ -34,8 +34,8 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
|
||||
aligned_16, get_ascend_device_type, nd_to_nz_2d)
|
||||
|
||||
|
||||
class AscendAttentionTorchairBackend(AscendAttentionBackend):
|
||||
@@ -185,7 +185,8 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._310P and attn_state == AscendAttentionState.PrefillNoCache:
|
||||
mask_nz = nd_to_nz_2d(attn_mask)
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)
|
||||
|
||||
@@ -381,7 +382,7 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# align q k v output tensors
|
||||
query = aligned_16(query)
|
||||
key = aligned_16(key)
|
||||
|
||||
@@ -42,8 +42,7 @@ from vllm_ascend.torchair.utils import (
|
||||
register_torchair_model, torchair_ops_patch,
|
||||
torchair_quant_method_register, write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_310p, get_ascend_soc_version,
|
||||
AscendSocVersion)
|
||||
AscendDeviceType, get_ascend_device_type)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -125,13 +124,13 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
max_num_tokens, tp_size)
|
||||
self.mc2_tokens_capacity = max_graph_batch_size
|
||||
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.mc2_tokens_capacity > 512:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.mc2_tokens_capacity > 512:
|
||||
logger.error(
|
||||
f"A3: the max number of tokens must smaller then 512, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A2 and self.mc2_tokens_capacity > 256:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910B and self.mc2_tokens_capacity > 256:
|
||||
logger.error(
|
||||
f"A2: the max number of tokens must smaller then 256, but now is {self.mc2_tokens_capacity}"
|
||||
)
|
||||
@@ -207,7 +206,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
positions, attn_metadata, num_tokens,
|
||||
intermediate_tensors, inputs_embeds):
|
||||
if with_prefill or self.enable_shared_expert_dp:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = super()._generate_dummy_run_hidden_states(
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
@@ -230,7 +229,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
assert isinstance(kv, tuple), "kv_cache must be a tuple"
|
||||
torch._dynamo.mark_static(kv[0])
|
||||
torch._dynamo.mark_static(kv[1])
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(num_tokens)
|
||||
@@ -371,7 +370,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_NZ)
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
@@ -384,7 +383,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
@@ -414,7 +413,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_distributed import \
|
||||
@@ -428,7 +427,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.tiling_schedule_optimize = get_ascend_device_type(
|
||||
) != AscendDeviceType._310P
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
@@ -531,8 +531,8 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
# NOTE: when enable_expert_parallel on A3, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
# Because we use x_active_mask for dispatch/combine op on A3, which requires that input shape should be same
|
||||
# on all EP ranks
|
||||
if get_ascend_soc_version(
|
||||
) == AscendSocVersion.A3 and self.parallel_config.enable_expert_parallel:
|
||||
if get_ascend_device_type(
|
||||
) == AscendDeviceType._910_93 and self.parallel_config.enable_expert_parallel:
|
||||
self._align_graph_size_divisible_by_tp_size()
|
||||
|
||||
def _align_graph_size_divisible_by_tp_size(self):
|
||||
|
||||
Reference in New Issue
Block a user