[Feature] Support DeepEP Low Latency (#4767)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
@@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            prefix=add_prefix("experts", prefix),
-        )
+        if not global_server_args_dict["enable_deepep_moe"]:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+            )
+        else:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+                deepep_mode=global_server_args_dict["deepep_mode"],
+            )
 
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
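For readers skimming the hunk above: the constructor now chooses one of three expert-parallel backends from the server arguments and, only for the DeepEP backend, forwards the new deepep_mode setting. Below is a minimal standalone sketch of that selection logic; the stub classes and the pick_moe_impl helper are illustrative stand-ins, not SGLang APIs.

# Illustrative stand-ins for the three backends named in the diff
# (the real classes live in sglang.srt.layers.moe).
class FusedMoE: ...          # default tensor-parallel fused MoE
class EPMoE(FusedMoE): ...   # expert-parallel MoE
class DeepEPMoE(EPMoE): ...  # expert-parallel MoE backed by DeepEP kernels


def pick_moe_impl(server_args: dict) -> type:
    """Mirror the MoEImpl selection above: DeepEP > EP > fused (TP) MoE."""
    if server_args.get("enable_deepep_moe"):
        return DeepEPMoE
    if server_args.get("enable_ep_moe"):
        return EPMoE
    return FusedMoE


assert pick_moe_impl({"enable_deepep_moe": True}) is DeepEPMoE
assert pick_moe_impl({"enable_ep_moe": True}) is EPMoE
assert pick_moe_impl({}) is FusedMoE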
@@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
             )
 
         if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
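The TODO above pins expert parallelism to the tensor-parallel world size, so each rank hosts an equal slice of the routed experts. A small worked example of that bookkeeping (illustrative only; the helper name is not from the codebase):

def local_expert_count(n_routed_experts: int, ep_size: int) -> int:
    # With tp == ep (the current assumption), every rank owns the same
    # number of routed experts; uneven splits are not supported here.
    assert n_routed_experts % ep_size == 0, "experts must split evenly across ranks"
    return n_routed_experts // ep_size


# e.g. a DeepSeek-V2-style config: 160 routed experts across 8 GPUs -> 20 per rank
assert local_expert_count(160, 8) == 20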
@@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
+                deepep_mode=global_server_args_dict["deepep_mode"],
                 async_finish=True,  # TODO
+                return_recv_hook=True,
             )
 
     def forward(
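The two new dispatcher arguments carry the feature: deepep_mode selects between DeepEP's throughput-oriented normal kernels and its low-latency kernels, and return_recv_hook=True asks the low-latency path to hand back a receive hook rather than blocking, so the communication can be overlapped with other work. The sketch below shows how a mode string might be resolved per batch; the helper and the "auto" semantics (normal for prefill, low latency for decode) are stated as assumptions, not taken from this diff.

def resolve_deepep_mode(deepep_mode: str, is_decode: bool) -> str:
    # Hypothetical resolution of the deepep_mode server argument.
    if deepep_mode == "auto":
        # Assumed behavior: normal (high-throughput) dispatch for prefill,
        # low-latency dispatch for decode batches.
        return "low_latency" if is_decode else "normal"
    if deepep_mode not in ("normal", "low_latency"):
        raise ValueError(f"unsupported deepep_mode: {deepep_mode}")
    return deepep_mode


assert resolve_deepep_mode("auto", is_decode=True) == "low_latency"
assert resolve_deepep_mode("auto", is_decode=False) == "normal"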
@@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
         )
-        if self.tp_size > 1:
-            recv_hidden_states, reorder_topk_ids, seg_indptr = (
-                self.deepep_dispatcher.dispatch(
-                    hidden_states,
-                    topk_idx,
-                    topk_weights,
-                    self.num_experts,
-                    forward_mode,
-                )
-            )
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
+            )
         final_hidden_states = (
             self.experts(
-                hidden_states=recv_hidden_states,
+                hidden_states=hidden_states,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
         )
-        if self.tp_size > 1:
+        if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states, forward_mode
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
             )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
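Putting the forward-path changes together: dispatch now returns the updated hidden states plus masked_m/expected_m (per-expert token counts that the low-latency grouped GEMM needs), and combine takes the top-k indices and weights so the routing weights can be applied on the return all-to-all. The pseudocode below summarizes that flow; dispatcher and experts are stand-ins with simplified signatures, not the real SGLang/DeepEP APIs.

def moe_forward_deepep(hidden_states, topk_idx, topk_weights, dispatcher, experts,
                       routed_scaling_factor, shared_output=None):
    # 1. All-to-all dispatch: each token travels to the ranks that own its experts.
    #    In low-latency mode the dispatcher also reports per-expert token counts.
    (hidden_states, topk_idx, topk_weights,
     reorder_topk_ids, seg_indptr, masked_m, expected_m) = dispatcher.dispatch(
        hidden_states, topk_idx, topk_weights
    )
    # 2. Run the local experts on the received tokens and apply the routed scale.
    out = experts(hidden_states, reorder_topk_ids, seg_indptr, masked_m, expected_m)
    out = out * routed_scaling_factor
    # 3. All-to-all combine: weighted expert outputs return to their source ranks.
    out = dispatcher.combine(out, topk_idx, topk_weights)
    # 4. Shared-expert output is added after the routed path, as in the diff.
    if shared_output is not None:
        out = out + shared_output
    return out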