support eplb for qwen3 (#6533)
This commit is contained in:
@@ -65,6 +65,7 @@ def fused_topk(
|
|||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
):
|
):
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
@@ -88,7 +89,7 @@ def fused_topk(
|
|||||||
|
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
@@ -355,12 +356,13 @@ def select_experts(
|
|||||||
assert (
|
assert (
|
||||||
num_token_non_padded is None
|
num_token_non_padded is None
|
||||||
), "num_token_non_padded is not yet supported in fused_topk"
|
), "num_token_non_padded is not yet supported in fused_topk"
|
||||||
assert expert_location_dispatch_info is None
|
# Qwen3MOE uses fused_topk
|
||||||
topk_weights, topk_ids = fused_topk(
|
topk_weights, topk_ids = fused_topk(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
topk=top_k,
|
topk=top_k,
|
||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -690,7 +690,9 @@ def _convert_global_physical_count_to_logical_count(
|
|||||||
)
|
)
|
||||||
logical_count.scatter_add_(
|
logical_count.scatter_add_(
|
||||||
dim=2,
|
dim=2,
|
||||||
index=physical_to_logical_map.unsqueeze(0).expand(dim_extra, -1, -1),
|
index=physical_to_logical_map.unsqueeze(0)
|
||||||
|
.expand(dim_extra, -1, -1)
|
||||||
|
.to(torch.int64),
|
||||||
src=global_physical_count,
|
src=global_physical_count,
|
||||||
)
|
)
|
||||||
return logical_count
|
return logical_count
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
@@ -67,6 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||||
|
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
ForwardBatch,
|
ForwardBatch,
|
||||||
@@ -86,28 +88,25 @@ logger = logging.getLogger(__name__)
|
|||||||
class Qwen3MoeSparseMoeBlock(nn.Module):
|
class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
layer_id: int,
|
||||||
config: Qwen3MoeConfig,
|
config: Qwen3MoeConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.layer_id = layer_id
|
||||||
if self.tp_size > config.num_experts:
|
if self.tp_size > config.num_experts:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
f"Tensor parallel size {self.tp_size} is greater than "
|
||||||
f"the number of experts {config.num_experts}."
|
f"the number of experts {config.num_experts}."
|
||||||
)
|
)
|
||||||
|
|
||||||
MoEImpl = (
|
self.experts = get_moe_impl_class()(
|
||||||
DeepEPMoE
|
num_experts=config.num_experts
|
||||||
if global_server_args_dict["enable_deepep_moe"]
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.experts = MoEImpl(
|
|
||||||
num_experts=config.num_experts,
|
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
|
layer_id=layer_id,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
@@ -131,7 +130,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
if global_server_args_dict["enable_deepep_moe"]:
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
# TODO: we will support tp < ep in the future
|
# TODO: we will support tp < ep in the future
|
||||||
self.ep_size = get_tensor_model_parallel_world_size()
|
self.ep_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_experts = config.num_experts
|
self.num_experts = (
|
||||||
|
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
|
||||||
|
)
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
self.renormalize = config.norm_topk_prob
|
self.renormalize = config.norm_topk_prob
|
||||||
|
|
||||||
@@ -139,7 +140,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
group=parallel_state.get_tp_group().device_group,
|
group=parallel_state.get_tp_group().device_group,
|
||||||
router_topk=self.top_k,
|
router_topk=self.top_k,
|
||||||
permute_fusion=True,
|
permute_fusion=True,
|
||||||
num_experts=config.num_experts,
|
num_experts=self.num_experts,
|
||||||
num_local_experts=config.num_experts // self.tp_size,
|
num_local_experts=config.num_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
@@ -157,8 +158,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_mode)
|
return self.forward_deepep(hidden_states, forward_mode)
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def get_moe_weights(self):
|
||||||
|
return [
|
||||||
|
x.data
|
||||||
|
for name, x in self.experts.named_parameters()
|
||||||
|
if name not in ["correction_bias"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
@@ -189,6 +196,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
top_k=self.top_k,
|
top_k=self.top_k,
|
||||||
use_grouped_topk=False,
|
use_grouped_topk=False,
|
||||||
renormalize=self.renormalize,
|
renormalize=self.renormalize,
|
||||||
|
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_idx = torch.full(
|
topk_idx = torch.full(
|
||||||
@@ -408,6 +418,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
if self.info.is_sparse:
|
if self.info.is_sparse:
|
||||||
self.mlp = Qwen3MoeSparseMoeBlock(
|
self.mlp = Qwen3MoeSparseMoeBlock(
|
||||||
|
layer_id=self.layer_id,
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
@@ -685,15 +696,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
|
||||||
MoEImpl = (
|
|
||||||
DeepEPMoE
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]
|
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
|
||||||
)
|
|
||||||
|
|
||||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
@@ -770,5 +773,19 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"Parameter {name} not found in params_dict")
|
logger.warning(f"Parameter {name} not found in params_dict")
|
||||||
|
|
||||||
|
self.routed_experts_weights_of_layer = {
|
||||||
|
layer_id: layer.mlp.get_moe_weights()
|
||||||
|
for layer_id, layer in enumerate(self.model.layers)
|
||||||
|
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_config_for_expert_location(cls, config):
|
||||||
|
return ModelConfigForExpertLocation(
|
||||||
|
num_layers=config.num_hidden_layers,
|
||||||
|
num_logical_experts=config.num_experts,
|
||||||
|
num_groups=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Qwen3MoeForCausalLM
|
EntryClass = Qwen3MoeForCausalLM
|
||||||
|
|||||||
Reference in New Issue
Block a user