Support EPLB in FusedMoE (#8448)
This commit is contained in:
@@ -47,6 +47,11 @@ class ExpertDistributionRecorder(ABC):
|
|||||||
rank: int,
|
rank: int,
|
||||||
):
|
):
|
||||||
if server_args.expert_distribution_recorder_mode is not None:
|
if server_args.expert_distribution_recorder_mode is not None:
|
||||||
|
assert (
|
||||||
|
expert_location_metadata is not None
|
||||||
|
), "ExpertLocationMetadata is required for expert distribution recording. One possible"
|
||||||
|
"reason is that you are using a model that does not support expert distribution"
|
||||||
|
"recording. Try setting `get_model_config_for_expert_location` in your model."
|
||||||
return _ExpertDistributionRecorderReal(
|
return _ExpertDistributionRecorderReal(
|
||||||
server_args, expert_location_metadata, rank
|
server_args, expert_location_metadata, rank
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ class ExpertLocationMetadata:
|
|||||||
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
|
def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
|
||||||
"""Trivial location - logical expert i corresponds to physical expert i"""
|
"""Trivial location - logical expert i corresponds to physical expert i"""
|
||||||
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
||||||
|
|
||||||
|
if common is None:
|
||||||
|
return None
|
||||||
|
|
||||||
num_physical_experts = common["num_physical_experts"]
|
num_physical_experts = common["num_physical_experts"]
|
||||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||||
num_layers = model_config_for_expert_location.num_layers
|
num_layers = model_config_for_expert_location.num_layers
|
||||||
@@ -109,6 +113,10 @@ class ExpertLocationMetadata:
|
|||||||
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
|
physical_to_logical_map = physical_to_logical_map.to(server_args.device)
|
||||||
|
|
||||||
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
||||||
|
|
||||||
|
if common is None:
|
||||||
|
return None
|
||||||
|
|
||||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||||
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
|
logical_to_all_physical_map = _compute_logical_to_all_physical_map(
|
||||||
physical_to_logical_map,
|
physical_to_logical_map,
|
||||||
@@ -133,6 +141,10 @@ class ExpertLocationMetadata:
|
|||||||
logical_count = logical_count.to(server_args.device)
|
logical_count = logical_count.to(server_args.device)
|
||||||
|
|
||||||
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
common = ExpertLocationMetadata._init_common(server_args, model_config)
|
||||||
|
|
||||||
|
if common is None:
|
||||||
|
return None
|
||||||
|
|
||||||
model_config_for_expert_location = common["model_config_for_expert_location"]
|
model_config_for_expert_location = common["model_config_for_expert_location"]
|
||||||
num_physical_experts = common["num_physical_experts"]
|
num_physical_experts = common["num_physical_experts"]
|
||||||
num_groups = model_config_for_expert_location.num_groups
|
num_groups = model_config_for_expert_location.num_groups
|
||||||
@@ -168,6 +180,9 @@ class ExpertLocationMetadata:
|
|||||||
ModelConfigForExpertLocation.from_model_config(model_config)
|
ModelConfigForExpertLocation.from_model_config(model_config)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_config_for_expert_location is None:
|
||||||
|
return None
|
||||||
|
|
||||||
num_physical_experts = (
|
num_physical_experts = (
|
||||||
model_config_for_expert_location.num_logical_experts
|
model_config_for_expert_location.num_logical_experts
|
||||||
+ server_args.ep_num_redundant_experts
|
+ server_args.ep_num_redundant_experts
|
||||||
@@ -398,10 +413,6 @@ class ModelConfigForExpertLocation:
|
|||||||
num_logical_experts: int
|
num_logical_experts: int
|
||||||
num_groups: Optional[int] = None
|
num_groups: Optional[int] = None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def init_dummy():
|
|
||||||
return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_config(model_config: ModelConfig):
|
def from_model_config(model_config: ModelConfig):
|
||||||
model_class, _ = get_model_architecture(model_config)
|
model_class, _ = get_model_architecture(model_config)
|
||||||
@@ -410,12 +421,12 @@ class ModelConfigForExpertLocation:
|
|||||||
model_config.hf_config
|
model_config.hf_config
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return ModelConfigForExpertLocation.init_dummy()
|
return None
|
||||||
|
|
||||||
|
|
||||||
def compute_initial_expert_location_metadata(
|
def compute_initial_expert_location_metadata(
|
||||||
server_args: ServerArgs, model_config: ModelConfig
|
server_args: ServerArgs, model_config: ModelConfig
|
||||||
) -> ExpertLocationMetadata:
|
) -> Optional[ExpertLocationMetadata]:
|
||||||
data = server_args.init_expert_location
|
data = server_args.init_expert_location
|
||||||
if data == "trivial":
|
if data == "trivial":
|
||||||
return ExpertLocationMetadata.init_trivial(server_args, model_config)
|
return ExpertLocationMetadata.init_trivial(server_args, model_config)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class ExpertLocationDispatchInfo:
|
|||||||
def init_new(cls, layer_id: int):
|
def init_new(cls, layer_id: int):
|
||||||
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
|
ep_dispatch_algorithm = global_server_args_dict["ep_dispatch_algorithm"]
|
||||||
expert_location_metadata = get_global_expert_location_metadata()
|
expert_location_metadata = get_global_expert_location_metadata()
|
||||||
|
assert expert_location_metadata is not None
|
||||||
|
|
||||||
if ep_dispatch_algorithm is None:
|
if ep_dispatch_algorithm is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -50,6 +50,8 @@ class ExpertLocationUpdater:
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
old_expert_location_metadata = get_global_expert_location_metadata()
|
old_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
|
assert old_expert_location_metadata is not None
|
||||||
|
|
||||||
_update_expert_weights(
|
_update_expert_weights(
|
||||||
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
routed_experts_weights_of_layer=routed_experts_weights_of_layer,
|
||||||
old_expert_location_metadata=old_expert_location_metadata,
|
old_expert_location_metadata=old_expert_location_metadata,
|
||||||
|
|||||||
@@ -183,6 +183,7 @@ class EPMoE(FusedMoE):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -196,6 +197,7 @@ class EPMoE(FusedMoE):
|
|||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -728,10 +730,19 @@ class EPMoE(FusedMoE):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
physical_expert_ids = (
|
global_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
get_global_expert_location_metadata().logical_to_all_physical(
|
if global_expert_location_metadata is None:
|
||||||
self.layer_id, expert_id
|
self._weight_loader_impl(
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
physical_expert_ids = global_expert_location_metadata.logical_to_all_physical(
|
||||||
|
self.layer_id, expert_id
|
||||||
)
|
)
|
||||||
for physical_expert_id in physical_expert_ids:
|
for physical_expert_id in physical_expert_ids:
|
||||||
self._weight_loader_physical(
|
self._weight_loader_physical(
|
||||||
@@ -778,6 +789,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
num_fused_shared_experts: int = 0,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -792,6 +804,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
@@ -62,8 +63,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
num_experts: int,
|
num_experts: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
|
layer_id: int,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
layer_id: Optional[int] = None,
|
num_fused_shared_experts: int = 0,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
reduce_results: bool = False,
|
reduce_results: bool = False,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
@@ -84,6 +86,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
self.layer_id = layer_id
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.tp_size = (
|
self.tp_size = (
|
||||||
@@ -91,6 +94,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
|
self.num_fused_shared_experts = num_fused_shared_experts
|
||||||
self.expert_map = None
|
self.expert_map = None
|
||||||
|
|
||||||
if enable_flashinfer_cutlass_moe and quant_config is None:
|
if enable_flashinfer_cutlass_moe and quant_config is None:
|
||||||
@@ -375,6 +379,45 @@ class FusedMoE(torch.nn.Module):
|
|||||||
shard_id: str,
|
shard_id: str,
|
||||||
expert_id: int,
|
expert_id: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
|
global_expert_location_metadata = get_global_expert_location_metadata()
|
||||||
|
if global_expert_location_metadata is None:
|
||||||
|
self._weight_loader_impl(
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if expert_id >= self.num_experts - self.num_fused_shared_experts:
|
||||||
|
# This is a shared expert.
|
||||||
|
physical_expert_ids = [expert_id]
|
||||||
|
else:
|
||||||
|
physical_expert_ids = (
|
||||||
|
global_expert_location_metadata.logical_to_all_physical(
|
||||||
|
self.layer_id, expert_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for physical_expert_id in physical_expert_ids:
|
||||||
|
self._weight_loader_physical(
|
||||||
|
param=param,
|
||||||
|
loaded_weight=loaded_weight,
|
||||||
|
weight_name=weight_name,
|
||||||
|
shard_id=shard_id,
|
||||||
|
expert_id=physical_expert_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _weight_loader_physical(
|
||||||
|
self,
|
||||||
|
param: torch.nn.Parameter,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
weight_name: str,
|
||||||
|
shard_id: str,
|
||||||
|
expert_id: int,
|
||||||
|
) -> None:
|
||||||
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
||||||
if expert_id == -1:
|
if expert_id == -1:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -325,6 +325,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
@@ -2112,6 +2113,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
if disable_reason is not None:
|
if disable_reason is not None:
|
||||||
global_server_args_dict["disable_shared_experts_fusion"] = True
|
global_server_args_dict["disable_shared_experts_fusion"] = True
|
||||||
|
self.num_fused_shared_experts = 0
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
||||||
|
|||||||
@@ -434,6 +434,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
@@ -740,10 +741,11 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
|
|||||||
global_server_args_dict["enable_deepep_moe"]
|
global_server_args_dict["enable_deepep_moe"]
|
||||||
or global_server_args_dict["enable_ep_moe"]
|
or global_server_args_dict["enable_ep_moe"]
|
||||||
):
|
):
|
||||||
disable_reason = "Deepseek GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
|
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
|
||||||
|
|
||||||
if disable_reason is not None:
|
if disable_reason is not None:
|
||||||
global_server_args_dict["disable_shared_experts_fusion"] = True
|
global_server_args_dict["disable_shared_experts_fusion"] = True
|
||||||
|
self.num_fused_shared_experts = 0
|
||||||
log_info_on_rank0(
|
log_info_on_rank0(
|
||||||
logger,
|
logger,
|
||||||
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
f"{disable_reason} Shared experts fusion optimization is disabled.",
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class GraniteMoeMoE(nn.Module):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
|
layer_id: int,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -71,6 +72,7 @@ class GraniteMoeMoE(nn.Module):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
|
layer_id=layer_id,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -203,6 +205,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
|||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
|
layer_id=layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.block_sparse_moe",
|
prefix=f"{prefix}.block_sparse_moe",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class Grok1MoE(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
layer_id: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -128,6 +129,7 @@ class Grok1MoE(nn.Module):
|
|||||||
self.experts = MoEImpl(
|
self.experts = MoEImpl(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
layer_id=layer_id,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
@@ -331,6 +333,7 @@ class Grok1DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
self.block_sparse_moe = Grok1MoE(
|
self.block_sparse_moe = Grok1MoE(
|
||||||
config=config,
|
config=config,
|
||||||
|
layer_id=layer_id,
|
||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
|
|||||||
@@ -163,6 +163,7 @@ class HunYuanSparseMoeBlock(nn.Module):
|
|||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
|
layer_id=layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class Llama4MoE(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Llama4TextConfig,
|
config: Llama4TextConfig,
|
||||||
|
layer_id: int,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
@@ -114,6 +115,7 @@ class Llama4MoE(nn.Module):
|
|||||||
num_experts=config.num_local_experts,
|
num_experts=config.num_local_experts,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=intermediate_size_moe,
|
intermediate_size=intermediate_size_moe,
|
||||||
|
layer_id=layer_id,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
apply_router_weight_on_input=True,
|
apply_router_weight_on_input=True,
|
||||||
@@ -373,6 +375,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
if is_moe_layer:
|
if is_moe_layer:
|
||||||
self.feed_forward = Llama4MoE(
|
self.feed_forward = Llama4MoE(
|
||||||
config=config,
|
config=config,
|
||||||
|
layer_id=layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("feed_forward", prefix),
|
prefix=add_prefix("feed_forward", prefix),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class MixtralMoE(nn.Module):
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
|
layer_id: int,
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
@@ -97,6 +98,7 @@ class MixtralMoE(nn.Module):
|
|||||||
self.experts = MoEImpl(
|
self.experts = MoEImpl(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
layer_id=layer_id,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
params_dtype=params_dtype,
|
params_dtype=params_dtype,
|
||||||
@@ -226,6 +228,7 @@ class MixtralDecoderLayer(nn.Module):
|
|||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
|
layer_id=layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("block_sparse_moe", prefix),
|
prefix=add_prefix("block_sparse_moe", prefix),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class OlmoeMoE(nn.Module):
|
|||||||
params_dtype: Optional[torch.dtype] = None,
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
|
layer_id: int = 0,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -89,6 +90,7 @@ class OlmoeMoE(nn.Module):
|
|||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
tp_size=tp_size,
|
tp_size=tp_size,
|
||||||
|
layer_id=layer_id,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -224,6 +226,7 @@ class OlmoeDecoderLayer(nn.Module):
|
|||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
|
layer_id=layer_id,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ class PhiMoE(nn.Module):
|
|||||||
self.experts = FusedMoE(
|
self.experts = FusedMoE(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
layer_id=layer_id,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
intermediate_size=intermediate_size,
|
intermediate_size=intermediate_size,
|
||||||
reduce_results=True,
|
reduce_results=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user