Support EPLB in FusedMoE (#8448)
@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
        rank: int,
    ):
        if server_args.expert_distribution_recorder_mode is not None:
            assert expert_location_metadata is not None, (
                "ExpertLocationMetadata is required for expert distribution recording. "
                "One possible reason is that you are using a model that does not support "
                "expert distribution recording. Try setting "
                "`get_model_config_for_expert_location` in your model."
            )
            return _ExpertDistributionRecorderReal(
                server_args, expert_location_metadata, rank
            )

@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)

        if common is None:
            return None

        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers

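For intuition, a minimal self-contained sketch of such a trivial layout (an illustration with made-up sizes, not the repository's implementation): the first num_logical_experts physical slots map to themselves, and the redundant slots added via ep_num_redundant_experts are assumed here to wrap around modulo the number of logical experts.

import torch

num_logical_experts = 8      # hypothetical model size
num_redundant_experts = 4    # hypothetical ep_num_redundant_experts
num_physical_experts = num_logical_experts + num_redundant_experts

# Physical slot p serves logical expert p % num_logical_experts.
physical_to_logical_map = torch.arange(num_physical_experts) % num_logical_experts
# -> tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3])
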
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)

        if common is None:
            return None

        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,

@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
        logical_count = logical_count.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)

        if common is None:
            return None

        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]
        num_groups = model_config_for_expert_location.num_groups

@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        if model_config_for_expert_location is None:
            return None

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            + server_args.ep_num_redundant_experts

@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)

@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()
            return None


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
) -> Optional[ExpertLocationMetadata]:
    data = server_args.init_expert_location
    if data == "trivial":
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

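Since compute_initial_expert_location_metadata can now return None (the annotation above changes to Optional[ExpertLocationMetadata]), callers are expected to handle the missing-metadata case. A hedged usage sketch with hypothetical variable names:

metadata = compute_initial_expert_location_metadata(server_args, model_config)
if metadata is None:
    # The model exposes no expert-location information
    # (ModelConfigForExpertLocation.from_model_config returned None),
    # so EPLB-related features simply stay disabled.
    pass
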
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
    def init_new(cls, layer_id: int):
        ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
        expert_location_metadata = get_global_expert_location_metadata()
        assert expert_location_metadata is not None

        if ep_dispatch_algorithm is None:
            return None

@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
        torch.cuda.empty_cache()

        old_expert_location_metadata = get_global_expert_location_metadata()
        assert old_expert_location_metadata is not None

        _update_expert_weights(
            routed_experts_weights_of_layer=routed_experts_weights_of_layer,
            old_expert_location_metadata=old_expert_location_metadata,

@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            top_k=top_k,
            num_fused_shared_experts=num_fused_shared_experts,
            layer_id=layer_id,
            params_dtype=params_dtype,
            quant_config=quant_config,

@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
        shard_id: str,
        expert_id: int,
    ) -> None:
        physical_expert_ids = (
            get_global_expert_location_metadata().logical_to_all_physical(
                self.layer_id, expert_id
        global_expert_location_metadata = get_global_expert_location_metadata()
        if global_expert_location_metadata is None:
            self._weight_loader_impl(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=weight_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            return

        physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
            self.layer_id, expert_id
        )
        for physical_expert_id in physical_expert_ids:
            self._weight_loader_physical(

@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
            tp_size=tp_size,

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,

@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        top_k: Optional[int] = None,
        layer_id: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        quant_config: Optional[QuantizationConfig] = None,

@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.layer_id = layer_id
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.tp_size = (

@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
        )
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = num_experts
        self.num_fused_shared_experts = num_fused_shared_experts
        self.expert_map = None

        if enable_flashinfer_cutlass_moe and quant_config is None:

@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        expert_id: int,
    ) -> None:

        global_expert_location_metadata = get_global_expert_location_metadata()
        if global_expert_location_metadata is None:
            self._weight_loader_impl(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=weight_name,
                shard_id=shard_id,
                expert_id=expert_id,
            )
            return

        if expert_id >= self.num_experts - self.num_fused_shared_experts:
            # This is a shared expert.
            physical_expert_ids = [expert_id]
        else:
            physical_expert_ids = (
                global_expert_location_metadata.logical_to_all_physical(
                    self.layer_id, expert_id
                )
            )

        for physical_expert_id in physical_expert_ids:
            self._weight_loader_physical(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=weight_name,
                shard_id=shard_id,
                expert_id=physical_expert_id,
            )

    def _weight_loader_physical(
        self,
        param: torch.nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
    ) -> None:
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            return

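This weight loader is the heart of the EPLB support in FusedMoE: a checkpoint stores one weight per logical expert, but a logical expert may be replicated on several physical slots, so the same tensor is written into every replica, while fused shared experts keep a 1:1 mapping. A minimal illustrative sketch of that fan-out, with a hypothetical mapping and loader callback (not the repository's API):

# Hypothetical per-layer mapping: logical expert id -> physical replica slots.
logical_to_all_physical = {0: [0, 64], 1: [1], 2: [2, 65]}

def load_logical_expert(write_slot, weight, logical_expert_id):
    # write_slot(weight, physical_expert_id) copies the tensor into one physical slot;
    # looping over the replicas duplicates the checkpoint weight onto each of them.
    for physical_expert_id in logical_to_all_physical[logical_expert_id]:
        write_slot(weight, physical_expert_id)
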
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,

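As an illustrative example of the expression above (made-up numbers): with 256 routed experts, one fused shared expert, and ep_num_redundant_experts=32, the layer is built with 256 + 1 + 32 = 289 experts, and top_k is likewise increased by one so the fused shared expert can be selected alongside the routed ones.
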
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):

        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            self.num_fused_shared_experts = 0
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",

@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,

@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
            global_server_args_dict["enable_deepep_moe"]
            or global_server_args_dict["enable_ep_moe"]
        ):
            disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."

        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            self.num_fused_shared_experts = 0
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",

@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,
            params_dtype=params_dtype,
            reduce_results=True,
            quant_config=quant_config,

@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )

@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        num_experts: int,
        top_k: int,
        hidden_size: int,

@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,

@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
        )
        self.block_sparse_moe = Grok1MoE(
            config=config,
            layer_id=layer_id,
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,

@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=False,
            layer_id=layer_id,
            quant_config=quant_config,
        )

@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
    def __init__(
        self,
        config: Llama4TextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):

@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
            num_experts=config.num_local_experts,
            hidden_size=config.hidden_size,
            intermediate_size=intermediate_size_moe,
            layer_id=layer_id,
            reduce_results=False,
            quant_config=quant_config,
            apply_router_weight_on_input=True,

@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                layer_id=layer_id,
                quant_config=quant_config,
                prefix=add_prefix("feed_forward", prefix),
            )

@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,

@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
        self.experts = MoEImpl(
            num_experts=num_experts,
            top_k=top_k,
            layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,

@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("block_sparse_moe", prefix),
        )

@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        layer_id: int = 0,
        prefix: str = "",
    ):
        super().__init__()

@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
            reduce_results=True,
            quant_config=quant_config,
            tp_size=tp_size,
            layer_id=layer_id,
            prefix=add_prefix("experts", prefix),
        )

@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            layer_id=layer_id,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,