support qwen3-next-fp8 deepep (#10622)
This commit is contained in:
@@ -25,12 +25,14 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.communicator import (
|
||||
LayerCommunicator,
|
||||
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
tp_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
tp_rank=tp_rank,
|
||||
tp_size=tp_size,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
self.experts = get_moe_impl_class(quant_config)(
|
||||
layer_id=self.layer_id,
|
||||
top_k=config.num_experts_per_tok,
|
||||
num_experts=config.num_experts,
|
||||
num_experts=config.num_experts
|
||||
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
quant_config=quant_config,
|
||||
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("shared_expert", prefix),
|
||||
**(
|
||||
dict(tp_rank=0, tp_size=1)
|
||||
if get_moe_a2a_backend().is_deepep()
|
||||
else {}
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.shared_expert = None
|
||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
||||
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
# TODO: we will support tp < ep in the future
|
||||
self.ep_size = get_moe_expert_parallel_world_size()
|
||||
self.num_experts = (
|
||||
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
|
||||
)
|
||||
self.top_k = config.num_experts_per_tok
|
||||
|
||||
def get_moe_weights(self):
|
||||
return [
|
||||
x.data
|
||||
for name, x in self.experts.named_parameters()
|
||||
if name not in ["correction_bias"]
|
||||
]
|
||||
|
||||
def _forward_shared_experts(self, hidden_states: torch.Tensor):
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
)
|
||||
return shared_output
|
||||
|
||||
def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
|
||||
shared_output = None
|
||||
if hidden_states.shape[0] > 0:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
|
||||
hidden_states.device
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states.add_(shared_output)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def _forward_router_experts(self, hidden_states: torch.Tensor):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if get_moe_a2a_backend().is_deepep():
|
||||
return self._forward_deepep(hidden_states, forward_batch)
|
||||
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||
if (
|
||||
self.alt_stream is not None
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
|
||||
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
|
||||
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
|
||||
from sglang.srt.utils import (
|
||||
LazyValue,
|
||||
add_prefix,
|
||||
is_cuda,
|
||||
is_npu,
|
||||
make_layers,
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
@@ -849,13 +857,14 @@ class Qwen3NextModel(nn.Module):
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
layer_id=i,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
hidden_states, residual = layer(
|
||||
layer_id=i,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
|
||||
self.lm_head = self.lm_head.float()
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
self._routed_experts_weights_of_layer = LazyValue(
|
||||
lambda: {
|
||||
layer_id: layer.mlp.get_moe_weights()
|
||||
for layer_id, layer in enumerate(self.model.layers)
|
||||
if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def routed_experts_weights_of_layer(self):
|
||||
return self._routed_experts_weights_of_layer.value
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user