support eplb for qwen3 (#6533)
This commit is contained in:
@@ -65,6 +65,7 @@ def fused_topk(
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -88,7 +89,7 @@ def fused_topk(
|
||||
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@@ -355,12 +356,13 @@ def select_experts(
|
||||
assert (
|
||||
num_token_non_padded is None
|
||||
), "num_token_non_padded is not yet supported in fused_topk"
|
||||
assert expert_location_dispatch_info is None
|
||||
# Qwen3MOE uses fused_topk
|
||||
topk_weights, topk_ids = fused_topk(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
|
||||
@@ -690,7 +690,9 @@ def _convert_global_physical_count_to_logical_count(
|
||||
)
|
||||
logical_count.scatter_add_(
|
||||
dim=2,
|
||||
index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
|
||||
index=physical_to_logical_map.unsqueeze(0)
|
||||
.expand(dim_extra, -1, -1)
|
||||
.to(torch.int64),
|
||||
src=global_physical_count,
|
||||
)
|
||||
return logical_count
|
||||
|
||||
@@ -55,7 +55,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
@@ -67,6 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
@@ -86,28 +88,25 @@ logger = logging.getLogger(__name__)
|
||||
class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
config: Qwen3MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.layer_id = layer_id
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_experts}."
|
||||
)
|
||||
|
||||
MoEImpl = (
|
||||
DeepEPMoE
|
||||
if global_server_args_dict["enable_deepep_moe"]
|
||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||
)
|
||||
|
||||
self.experts = MoEImpl(
|
||||
num_experts=config.num_experts,
|
||||
self.experts = get_moe_impl_class()(
|
||||
num_experts=config.num_experts
|
||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||
top_k=config.num_experts_per_tok,
|
||||
layer_id=layer_id,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
renormalize=config.norm_topk_prob,
|
||||
@@ -131,7 +130,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
if global_server_args_dict["enable_deepep_moe"]:
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_tensor_model_parallel_world_size()
|
||||
self.num_experts = config.num_experts
|
||||
self.num_experts = (
|
||||
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
|
||||
)
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.renormalize = config.norm_topk_prob
|
||||
|
||||
@@ -139,7 +140,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=config.num_experts,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=config.num_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
@@ -157,8 +158,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_mode)
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
x.data
|
||||
for name, x in self.experts.named_parameters()
|
||||
if name not in ["correction_bias"]
|
||||
]
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
@@ -189,6 +196,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=self.renormalize,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_idx = torch.full(
|
||||
@@ -408,6 +418,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
if self.info.is_sparse:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(
|
||||
layer_id=self.layer_id,
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
@@ -685,15 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
MoEImpl = (
|
||||
DeepEPMoE
|
||||
if global_server_args_dict["enable_deepep_moe"]
|
||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||
)
|
||||
|
||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
@@ -770,5 +773,19 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
self.routed_experts_weights_of_layer = {
|
||||
layer_id: layer.mlp.get_moe_weights()
|
||||
for layer_id, layer in enumerate(self.model.layers)
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_model_config_for_expert_location(cls, config):
|
||||
return ModelConfigForExpertLocation(
|
||||
num_layers=config.num_hidden_layers,
|
||||
num_logical_experts=config.num_experts,
|
||||
num_groups=None,
|
||||
)
|
||||
|
||||
|
||||
EntryClass = Qwen3MoeForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user