Minor code cleanup refactor for DeepSeek models (#6324)
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
|||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_gemm import (
|
from deep_gemm import (
|
||||||
@@ -40,7 +41,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
tma_align_input_scale,
|
tma_align_input_scale,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -1173,3 +1174,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return down_output
|
return down_output
|
||||||
|
|
||||||
|
|
||||||
|
def get_moe_impl_class():
|
||||||
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
|
return DeepEPMoE
|
||||||
|
if global_server_args_dict["enable_ep_moe"]:
|
||||||
|
return EPMoE
|
||||||
|
return FusedMoE
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
@@ -222,13 +222,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
|
self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
|
||||||
|
|
||||||
MoEImpl = (
|
self.experts = get_moe_impl_class()(
|
||||||
DeepEPMoE
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]
|
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.experts = MoEImpl(
|
|
||||||
num_experts=config.n_routed_experts + self.n_share_experts_fusion,
|
num_experts=config.n_routed_experts + self.n_share_experts_fusion,
|
||||||
top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
|
top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -251,26 +245,19 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||||
# disable tp for shared experts when enable deepep moe
|
# disable tp for shared experts when enable deepep moe
|
||||||
if not global_server_args_dict["enable_deepep_moe"]:
|
self.shared_experts = DeepseekV2MLP(
|
||||||
self.shared_experts = DeepseekV2MLP(
|
hidden_size=config.hidden_size,
|
||||||
hidden_size=config.hidden_size,
|
intermediate_size=intermediate_size,
|
||||||
intermediate_size=intermediate_size,
|
hidden_act=config.hidden_act,
|
||||||
hidden_act=config.hidden_act,
|
quant_config=quant_config,
|
||||||
quant_config=quant_config,
|
reduce_results=False,
|
||||||
reduce_results=False,
|
prefix=add_prefix("shared_experts", prefix),
|
||||||
prefix=add_prefix("shared_experts", prefix),
|
**(
|
||||||
)
|
dict(tp_rank=0, tp_size=1)
|
||||||
else:
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
self.shared_experts = DeepseekV2MLP(
|
else {}
|
||||||
hidden_size=config.hidden_size,
|
),
|
||||||
intermediate_size=intermediate_size,
|
)
|
||||||
hidden_act=config.hidden_act,
|
|
||||||
quant_config=quant_config,
|
|
||||||
reduce_results=False,
|
|
||||||
prefix=add_prefix("shared_experts", prefix),
|
|
||||||
tp_rank=0,
|
|
||||||
tp_size=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]:
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
@@ -1726,12 +1713,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
MoEImpl = (
|
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||||
DeepEPMoE
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]
|
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
|
||||||
)
|
|
||||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
|
|||||||
Reference in New Issue
Block a user