diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index da5e8e3..aa18942 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1048,7 +1048,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts) - elif fused_moe_state == FusedMoEState.AllGather: + elif fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.NaiveMulticast + ]: return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1225,6 +1227,22 @@ class AscendFusedMoE(FusedMoE): self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) + def naive_multicast(self, x: torch.Tensor, + cu_tokens_across_dp_cpu: torch.Tensor): + assert (len(x.shape) == 2) + buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), + device=x.device, + dtype=x.dtype) + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + buffer[start:end, :].copy_(x) + for idx in range(self.dp_size): + start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] + end = cu_tokens_across_dp_cpu[idx] + get_dp_group().broadcast(buffer[start:end, :], idx) + return buffer + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -1250,9 +1268,10 @@ class AscendFusedMoE(FusedMoE): shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP - and not replace_allreduce): + if (tp_size > 1 and fused_moe_state not in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ] and not replace_allreduce): if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1267,21 +1286,31 @@ class AscendFusedMoE(FusedMoE): tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] - if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: - # NOTE: When in torchair graph, it has been padded in model_runner_v1 - if not self.torchair_graph_enabled or is_prefill: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - hidden_states = get_dp_group().all_gather(hidden_states, 0) - router_logits = get_dp_group().all_gather(router_logits, 0) + if self.dp_size > 1: + if fused_moe_state == FusedMoEState.AllGather: + # NOTE: When in torchair graph, it has been padded in model_runner_v1 + if not self.torchair_graph_enabled: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: + max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, + max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, + max_num_tokens_across_dp - num_tokens)) + hidden_states = get_dp_group().all_gather(hidden_states, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) + elif fused_moe_state == FusedMoEState.NaiveMulticast: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + hidden_states = self.naive_multicast(hidden_states, + cu_tokens_across_dp_cpu) + router_logits = self.naive_multicast(router_logits, + cu_tokens_across_dp_cpu) # Matrix multiply. e_hidden_states = self.quant_method.apply( @@ -1310,28 +1339,40 @@ class AscendFusedMoE(FusedMoE): if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP - and not replace_allreduce): + if (tp_size > 1 and fused_moe_state not in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ] and not replace_allreduce): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) if num_tokens < tp_size: final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) - elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) - final_hidden_states = final_hidden_states[:num_tokens] - dispose_tensor(e_hidden_states) + elif self.dp_size > 1: + if fused_moe_state == FusedMoEState.NaiveMulticast: + start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ + self.dp_rank - 1] + end = cu_tokens_across_dp_cpu[self.dp_rank] + final_hidden_states = get_dp_group().all_reduce( + e_hidden_states) + final_hidden_states = final_hidden_states[start:end, :] + dispose_tensor(e_hidden_states) + elif fused_moe_state == FusedMoEState.AllGather: + final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + e_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + final_hidden_states = final_hidden_states[:num_tokens] + dispose_tensor(e_hidden_states) else: final_hidden_states = e_hidden_states - if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather - or fused_moe_state == FusedMoEState.AllGatherEP): + if tp_size > 1 and fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.AllGatherEP, + FusedMoEState.NaiveMulticast + ]: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 5bfecaa..a0c90ab 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -780,7 +780,9 @@ class AscendW8A8DynamicFusedMoEMethod: log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts) - elif fused_moe_state == FusedMoEState.AllGather: + elif fused_moe_state in [ + FusedMoEState.AllGather, FusedMoEState.NaiveMulticast + ]: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d02c2ec..250e785 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -419,6 +419,7 @@ class FusedMoEState(Enum): All2All = 1 MC2 = 2 AllGatherEP = 3 + NaiveMulticast = 4 # TODO(zzzzwwjj): add soc_version to choose branch @@ -430,7 +431,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool, and is_deepseek_v3_r1): return FusedMoEState.AllGatherEP elif ep_size == 1: - return FusedMoEState.AllGather + if with_prefill: + return FusedMoEState.NaiveMulticast + else: + return FusedMoEState.AllGather # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: return FusedMoEState.All2All