Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -387,6 +387,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.ep_size = get_moe_expert_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
self.n_shared_experts = config.n_shared_experts
|
||||
self.num_fused_shared_experts = (
|
||||
@@ -480,11 +481,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_experts", prefix),
|
||||
**(
|
||||
dict(tp_rank=0, tp_size=1)
|
||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
else {}
|
||||
),
|
||||
**(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
|
||||
)
|
||||
is_packed_weight = hasattr(
|
||||
self.shared_experts.gate_up_proj.quant_method, "quant_config"
|
||||
@@ -531,6 +528,77 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
|
||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
|
||||
def forward_normal_dual_stream(
|
||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
|
||||
if self.ep_size > 1:
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
final_hidden_states += shared_output
|
||||
else:
|
||||
final_hidden_states += shared_output
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_normal(
|
||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||
) -> torch.Tensor:
|
||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||
self.shared_experts.gate_up_proj
|
||||
):
|
||||
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
|
||||
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
if self.topk is not None:
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
kwargs["router_logits"] = router_logits
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
if self.ep_size > 1:
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
if shared_output is not None:
|
||||
final_hidden_states += shared_output
|
||||
else:
|
||||
if shared_output is not None:
|
||||
final_hidden_states += shared_output
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user