Minor code cleanup refactor for DeepSeek models (#6324)
@@ -5,6 +5,7 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
     from deep_gemm import (
@@ -40,7 +41,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -1173,3 +1174,11 @@ class DeepEPMoE(EPMoE):
         )
 
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
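The new `get_moe_impl_class()` helper (added to `sglang.srt.layers.moe.ep_moe.layer`, as the import change below shows) centralizes the flag-based choice between DeepEPMoE, EPMoE, and FusedMoE that was previously duplicated at each call site. A minimal, self-contained sketch of the selection pattern, using stand-in classes and a plain dict instead of sglang's real layers and `global_server_args_dict`:

# Sketch only: the classes and dict below are illustrative stand-ins,
# not the real sglang implementations.
class FusedMoE: ...
class EPMoE(FusedMoE): ...
class DeepEPMoE(EPMoE): ...

server_args = {"enable_deepep_moe": False, "enable_ep_moe": True}

def get_moe_impl_class():
    # DeepEP takes precedence over plain expert parallelism; the fused
    # Triton implementation is the default.
    if server_args["enable_deepep_moe"]:
        return DeepEPMoE
    if server_args["enable_ep_moe"]:
        return EPMoE
    return FusedMoE

assert get_moe_impl_class() is EPMoE

With this helper in place, both `DeepseekV2MoE.__init__` and the weight-loading code in the hunks below ask for the class once instead of repeating the inline ternary.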
@@ -52,7 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
@@ -222,13 +222,7 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        self.experts = MoEImpl(
+        self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts + self.n_share_experts_fusion,
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
@@ -251,26 +245,19 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            if not global_server_args_dict["enable_deepep_moe"]:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                )
-            else:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                    tp_rank=0,
-                    tp_size=1,
-                )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
 
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
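The two nearly identical `DeepseekV2MLP` constructions above collapse into a single call by splatting an optional kwargs dict: `tp_rank=0, tp_size=1` are passed only when DeepEP MoE is enabled, and nothing extra is passed otherwise. A small, self-contained sketch of that conditional-kwargs pattern (the function and values here are illustrative, not the real DeepseekV2MLP signature):

def make_mlp(hidden_size, tp_rank=None, tp_size=None):
    # Stand-in for DeepseekV2MLP: just records the arguments it received.
    return {"hidden_size": hidden_size, "tp_rank": tp_rank, "tp_size": tp_size}

enable_deepep_moe = True

mlp = make_mlp(
    hidden_size=4096,
    # The dict is splatted in only when the flag is set; otherwise the
    # callee's defaults apply, as in the old non-DeepEP branch.
    **(dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}),
)
assert mlp == {"hidden_size": 4096, "tp_rank": 0, "tp_size": 1}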
@@ -1726,12 +1713,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",