Support qwen3 deepep (#6120)
This commit is contained in:
@@ -607,6 +607,9 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if hidden_states.shape[0] != 0:
|
if hidden_states.shape[0] != 0:
|
||||||
|
if residual is None:
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
else:
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
|
|||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
|
parallel_state,
|
||||||
split_tensor_along_last_dim,
|
split_tensor_along_last_dim,
|
||||||
tensor_model_parallel_all_gather,
|
tensor_model_parallel_all_gather,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
@@ -54,8 +55,10 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||||
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
@@ -65,11 +68,15 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||||
from sglang.srt.utils import add_prefix
|
from sglang.srt.utils import DeepEPMode, add_prefix
|
||||||
|
|
||||||
Qwen3MoeConfig = None
|
Qwen3MoeConfig = None
|
||||||
|
|
||||||
@@ -92,7 +99,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
f"the number of experts {config.num_experts}."
|
f"the number of experts {config.num_experts}."
|
||||||
)
|
)
|
||||||
|
|
||||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
MoEImpl = (
|
||||||
|
DeepEPMoE
|
||||||
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
|
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||||
|
)
|
||||||
|
|
||||||
self.experts = MoEImpl(
|
self.experts = MoEImpl(
|
||||||
num_experts=config.num_experts,
|
num_experts=config.num_experts,
|
||||||
@@ -102,6 +113,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
|
**(
|
||||||
|
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
|
||||||
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
|
else {}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
self.gate = ReplicatedLinear(
|
||||||
@@ -112,7 +128,37 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
prefix=add_prefix("gate", prefix),
|
prefix=add_prefix("gate", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
|
# TODO: we will support tp < ep in the future
|
||||||
|
self.ep_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.num_experts = config.num_experts
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.renormalize = config.norm_topk_prob
|
||||||
|
|
||||||
|
self.deepep_dispatcher = DeepEPDispatcher(
|
||||||
|
group=parallel_state.get_tp_group().device_group,
|
||||||
|
router_topk=self.top_k,
|
||||||
|
permute_fusion=True,
|
||||||
|
num_experts=config.num_experts,
|
||||||
|
num_local_experts=config.num_experts // self.tp_size,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
params_dtype=config.torch_dtype,
|
||||||
|
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
||||||
|
async_finish=True, # TODO
|
||||||
|
return_recv_hook=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
if not global_server_args_dict["enable_deepep_moe"]:
|
||||||
|
return self.forward_normal(hidden_states)
|
||||||
|
else:
|
||||||
|
return self.forward_deepep(hidden_states, forward_mode)
|
||||||
|
|
||||||
|
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
@@ -126,6 +172,68 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
|
|
||||||
|
def forward_deepep(
|
||||||
|
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if (
|
||||||
|
forward_mode is not None
|
||||||
|
and not forward_mode.is_idle()
|
||||||
|
and hidden_states.shape[0] > 0
|
||||||
|
):
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
|
topk_weights, topk_idx = select_experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=self.top_k,
|
||||||
|
use_grouped_topk=False,
|
||||||
|
renormalize=self.renormalize,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_idx = torch.full(
|
||||||
|
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||||
|
)
|
||||||
|
topk_weights = torch.empty(
|
||||||
|
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
if self.ep_size > 1:
|
||||||
|
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||||
|
(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
reorder_topk_ids,
|
||||||
|
num_recv_tokens_per_expert,
|
||||||
|
seg_indptr,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
) = self.deepep_dispatcher.dispatch(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
forward_mode=forward_mode,
|
||||||
|
)
|
||||||
|
final_hidden_states = self.experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
|
seg_indptr=seg_indptr,
|
||||||
|
masked_m=masked_m,
|
||||||
|
expected_m=expected_m,
|
||||||
|
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||||
|
forward_mode=forward_mode,
|
||||||
|
)
|
||||||
|
if self.ep_size > 1:
|
||||||
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
|
final_hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
forward_mode,
|
||||||
|
)
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class Qwen3MoeAttention(nn.Module):
|
class Qwen3MoeAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -403,7 +511,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||||
|
|
||||||
# TODO: use reduce-scatter in MLP to avoid this scatter
|
# TODO: use reduce-scatter in MLP to avoid this scatter
|
||||||
# Scatter
|
# Scatter
|
||||||
@@ -577,7 +685,13 @@ class Qwen3MoeForCausalLM(nn.Module):
|
|||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
|
MoEImpl = (
|
||||||
|
DeepEPMoE
|
||||||
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
|
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||||
|
)
|
||||||
|
|
||||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
|
|||||||
Reference in New Issue
Block a user