init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -40,17 +40,18 @@ from vllm.model_executor.layers.quantization.base_config import \
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
get_rm_router_logits_state, is_310p,
vllm_version_is)
def torchair_fused_experts_with_mc2(
@@ -802,6 +803,7 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
try:
device_group = get_mc2_group().device_group
@@ -883,6 +885,8 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_forward_context().fused_moe_state
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
if fused_moe_state == FusedMoEState.MC2:
return torchair_fused_experts_with_mc2(
@@ -1013,45 +1017,70 @@ class TorchairAscendFusedMoE(FusedMoE):
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
if expert_map_path and os.path.exists(expert_map_path):
# moe expert load balance
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
self.global_num_experts)
self.local_num_experts, self.expert_map = \
expert_load_balancer.get_rank_placement_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.global_redundant_expert_num = \
expert_load_balancer.get_global_redundant_expert_num()
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static EPLB: initialize from the expert map file at expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
os.R_OK):
self.expert_load_balancer = ExpertLoadBalancer(
self.expert_map_path, self.global_num_experts)
self.local_num_experts, self.expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.global_redundant_expert_num = (
self.expert_load_balancer.get_global_redundant_expert_num())
else:
# Create a tensor of size num_experts filled with -1
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic EPLB: initialize with default maps when no expert_map_path is usable
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.local_num_experts, self.expert_map = determine_default_expert_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
local_num_experts = (torch.sum(self.expert_map != -1)
if self.expert_map is not None else num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe and \
self.multistream_overlap_shared_expert = \
ascend_config.multistream_overlap_shared_expert and \
self.torchair_graph_enabled
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
self.moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_config=quant_config)
if vllm_version_is("0.10.2"):
self.moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_config=quant_config)
else:
self.moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
)
if quant_config is None:
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
self.moe)
@@ -1066,8 +1095,11 @@ class TorchairAscendFusedMoE(FusedMoE):
assert self.quant_method is not None
local_num_experts = torch.sum(self.expert_map != -1) \
if self.expert_map is not None else num_experts
self.moe_load = None
local_num_experts = (torch.sum(self.expert_map != -1)
if self.expert_map is not None else num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
moe_quant_params = {
"num_experts": local_num_experts,
@@ -1126,23 +1158,25 @@ class TorchairAscendFusedMoE(FusedMoE):
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
if self.enable_shared_expert_dp and fused_moe_state == FusedMoEState.MC2:
fused_moe_state = FusedMoEState.All2All
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
if self.enable_multistream_moe:
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicFusedMoEMethod
if self.multistream_overlap_shared_expert:
if not self.rm_router_logits:
router_logits, _ = gate(hidden_states)
if hasattr(self.quant_method, "quant_method") and \
isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod
TorchairAscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
if not self.multistream_overlap_shared_expert or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
@@ -1160,31 +1194,33 @@ class TorchairAscendFusedMoE(FusedMoE):
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
]):
if tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if not replace_allreduce:
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
@@ -1206,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE):
router_logits = get_dp_group().all_gather(router_logits, 0)
elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
if vllm_version_is("0.10.2"):
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
else:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_sp(1)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
if self.rm_router_logits:
@@ -1236,7 +1276,8 @@ class TorchairAscendFusedMoE(FusedMoE):
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None,
and self.multistream_overlap_shared_expert and not is_prefill else
None,
mc2_mask=mc2_mask,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
@@ -1246,6 +1287,11 @@ class TorchairAscendFusedMoE(FusedMoE):
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if self.dynamic_eplb and isinstance(
e_hidden_states, tuple) and len(e_hidden_states) == 3:
self.moe_load += e_hidden_states[2] if e_hidden_states[1] == 0 else \
torch.cat(e_hidden_states[2][:1], e_hidden_states[2][1:] - e_hidden_states[2][:-1])
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
@@ -1269,8 +1315,8 @@ class TorchairAscendFusedMoE(FusedMoE):
final_hidden_states = final_hidden_states[start:end, :]
dispose_tensor(e_hidden_states)
elif fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = data_parallel_reduce_scatter(
e_hidden_states, dim=0)
final_hidden_states = get_dp_group().reduce_scatter(
e_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else:
@@ -1290,6 +1336,19 @@ class TorchairAscendFusedMoE(FusedMoE):
else:
return final_hidden_states
def update_expert_map(self, new_expert_map):
    """Replace the expert map held by this layer.

    Args:
        new_expert_map: Object stored as ``self.expert_map``. It is
            assigned as-is, without copying or validation.
    """
    self.expert_map = new_expert_map
def get_map(self):
    """Return the object currently stored in ``self.expert_map``."""
    current_map = self.expert_map
    return current_map
def get_log2phy_map(self):
    """Return the logical-to-physical expert map stored in ``self.log2phy``.

    ``self.log2phy`` is the attribute this layer assigns during
    initialization and passes to the quant method as ``log2phy=`` in the
    forward pass, so the getter exposes that same object.
    """
    # NOTE(review): this previously returned ``self.logical_to_physical_map``,
    # which no visible code in this class assigns; confirm no caller relied
    # on that attribute coming from a parent class.
    return self.log2phy
def clear_moe_load(self):
    """Zero ``self.moe_load`` in place; do nothing when it is ``None``."""
    load = self.moe_load
    if load is None:
        return
    # In-place reset keeps the same tensor object alive for any holder of it.
    load.zero_()
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp(