Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -387,6 +387,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
):
|
):
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.ep_size = get_moe_expert_parallel_world_size()
|
||||||
self.routed_scaling_factor = config.routed_scaling_factor
|
self.routed_scaling_factor = config.routed_scaling_factor
|
||||||
self.n_shared_experts = config.n_shared_experts
|
self.n_shared_experts = config.n_shared_experts
|
||||||
self.num_fused_shared_experts = (
|
self.num_fused_shared_experts = (
|
||||||
@@ -480,11 +481,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
prefix=add_prefix("shared_experts", prefix),
|
prefix=add_prefix("shared_experts", prefix),
|
||||||
**(
|
**(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
|
||||||
dict(tp_rank=0, tp_size=1)
|
|
||||||
if global_server_args_dict["moe_a2a_backend"].is_deepep()
|
|
||||||
else {}
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
is_packed_weight = hasattr(
|
is_packed_weight = hasattr(
|
||||||
self.shared_experts.gate_up_proj.quant_method, "quant_config"
|
self.shared_experts.gate_up_proj.quant_method, "quant_config"
|
||||||
@@ -531,6 +528,77 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
|
|
||||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||||
|
|
||||||
|
def forward_normal_dual_stream(
    self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
) -> torch.Tensor:
    """Compute the MoE output with shared and routed experts on separate streams.

    The shared experts are launched on the current CUDA stream while the
    gate, top-k selection and routed experts are launched on
    ``self.alt_stream``. The two streams are joined before the results are
    combined.

    Args:
        hidden_states: Input activations passed to the gate, the shared
            experts and the routed experts.
        can_fuse_mlp_allreduce: When true, the tensor-parallel all-reduce is
            skipped here (presumably the caller fuses it into a later op —
            TODO confirm against the decoder layer).

    Returns:
        The combined routed + shared expert output.
    """
    current_stream = torch.cuda.current_stream()
    # Make the side stream wait for work already queued on the current stream
    # (e.g. whatever produced ``hidden_states``) before it starts.
    self.alt_stream.wait_stream(current_stream)
    shared_output = self._forward_shared_experts(hidden_states)

    with torch.cuda.stream(self.alt_stream):
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        kwargs = {"hidden_states": hidden_states}
        if self.topk is not None:
            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
        else:
            kwargs["router_logits"] = router_logits
        final_hidden_states = self.experts(**kwargs)
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
    # Join: the current stream must not consume the routed output before the
    # side stream has produced it.
    current_stream.wait_stream(self.alt_stream)

    needs_all_reduce = self.tp_size > 1 and not can_fuse_mlp_allreduce
    if self.ep_size > 1:
        # With expert parallelism the shared output is added *after* the
        # all-reduce. NOTE(review): __init__ builds the shared experts with
        # tp_size=1 in this mode, so presumably every rank already holds the
        # full shared output and reducing it would over-count — confirm.
        if needs_all_reduce:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states
            )
        # Guard against a missing shared output, mirroring forward_normal.
        if shared_output is not None:
            final_hidden_states += shared_output
    else:
        # Guard against a missing shared output, mirroring forward_normal.
        if shared_output is not None:
            final_hidden_states += shared_output
        if needs_all_reduce:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states
            )
    return final_hidden_states
|
||||||
|
|
||||||
|
def forward_normal(
    self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
) -> torch.Tensor:
    """Compute the MoE output on a single stream.

    Delegates to ``forward_cpu`` when the block has shared experts whose
    ``gate_up_proj`` uses the Intel AMX backend. Otherwise runs the shared
    experts, the gate, optional top-k selection and the routed experts in
    sequence, then combines the routed and shared outputs.

    Args:
        hidden_states: Input activations passed to the gate, the shared
            experts and the routed experts.
        can_fuse_mlp_allreduce: When true, the tensor-parallel all-reduce is
            skipped here (presumably fused by the caller — TODO confirm).

    Returns:
        The combined expert output.
    """
    if hasattr(self, "shared_experts") and use_intel_amx_backend(
        self.shared_experts.gate_up_proj
    ):
        return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)

    shared_out = self._forward_shared_experts(hidden_states)

    # logits: (num_tokens, n_experts)
    logits = self.gate(hidden_states)
    expert_inputs = {"hidden_states": hidden_states}
    if self.topk is None:
        expert_inputs["router_logits"] = logits
    else:
        expert_inputs["topk_output"] = self.topk(hidden_states, logits)
    routed_out = self.experts(**expert_inputs)

    if not (_is_cuda or _use_aiter):
        # On CUDA / aiter the scaling is fused in biased_grouped_topk, so it
        # is only applied explicitly on the remaining backends.
        routed_out *= self.routed_scaling_factor

    do_all_reduce = self.tp_size > 1 and not can_fuse_mlp_allreduce
    add_shared_after_reduce = self.ep_size > 1
    has_shared = shared_out is not None

    # Without expert parallelism the shared output joins before the reduce.
    if has_shared and not add_shared_after_reduce:
        routed_out += shared_out
    if do_all_reduce:
        routed_out = tensor_model_parallel_all_reduce(routed_out)
    # With expert parallelism the shared output is added to the reduced result.
    if has_shared and add_shared_after_reduce:
        routed_out += shared_out
    return routed_out
|
||||||
|
|
||||||
|
|
||||||
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
|
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user