support qwen3-next-fp8 deepep (#10622)
This commit is contained in:
@@ -25,12 +25,14 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
|
get_moe_expert_parallel_world_size,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
)
|
)
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
|
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.communicator import (
|
from sglang.srt.layers.communicator import (
|
||||||
LayerCommunicator,
|
LayerCommunicator,
|
||||||
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
reduce_results: bool = True,
|
reduce_results: bool = True,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
tp_rank: Optional[int] = None,
|
||||||
|
tp_size: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("gate_up_proj", prefix),
|
prefix=add_prefix("gate_up_proj", prefix),
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
tp_size=tp_size,
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=reduce_results,
|
reduce_results=reduce_results,
|
||||||
prefix=add_prefix("down_proj", prefix),
|
prefix=add_prefix("down_proj", prefix),
|
||||||
|
tp_rank=tp_rank,
|
||||||
|
tp_size=tp_size,
|
||||||
)
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
self.experts = get_moe_impl_class(quant_config)(
|
self.experts = get_moe_impl_class(quant_config)(
|
||||||
layer_id=self.layer_id,
|
layer_id=self.layer_id,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
num_experts=config.num_experts,
|
num_experts=config.num_experts
|
||||||
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
prefix=add_prefix("shared_expert", prefix),
|
prefix=add_prefix("shared_expert", prefix),
|
||||||
|
**(
|
||||||
|
dict(tp_rank=0, tp_size=1)
|
||||||
|
if get_moe_a2a_backend().is_deepep()
|
||||||
|
else {}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.shared_expert = None
|
self.shared_expert = None
|
||||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
||||||
|
|
||||||
|
if get_moe_a2a_backend().is_deepep():
|
||||||
|
# TODO: we will support tp < ep in the future
|
||||||
|
self.ep_size = get_moe_expert_parallel_world_size()
|
||||||
|
self.num_experts = (
|
||||||
|
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
|
||||||
|
)
|
||||||
|
self.top_k = config.num_experts_per_tok
|
||||||
|
|
||||||
|
def get_moe_weights(self):
|
||||||
|
return [
|
||||||
|
x.data
|
||||||
|
for name, x in self.experts.named_parameters()
|
||||||
|
if name not in ["correction_bias"]
|
||||||
|
]
|
||||||
|
|
||||||
def _forward_shared_experts(self, hidden_states: torch.Tensor):
|
def _forward_shared_experts(self, hidden_states: torch.Tensor):
|
||||||
shared_output = None
|
shared_output = None
|
||||||
if self.shared_expert is not None:
|
if self.shared_expert is not None:
|
||||||
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
return shared_output
|
return shared_output
|
||||||
|
|
||||||
|
def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
|
||||||
|
shared_output = None
|
||||||
|
if hidden_states.shape[0] > 0:
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
|
topk_weights, topk_idx, _ = self.topk(
|
||||||
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||||
|
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
|
||||||
|
hidden_states.device
|
||||||
|
)
|
||||||
|
final_hidden_states = self.experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
topk_idx=topk_idx,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
if shared_output is not None:
|
||||||
|
final_hidden_states.add_(shared_output)
|
||||||
|
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
def _forward_router_experts(self, hidden_states: torch.Tensor):
|
def _forward_router_experts(self, hidden_states: torch.Tensor):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
|
|
||||||
|
if get_moe_a2a_backend().is_deepep():
|
||||||
|
return self._forward_deepep(hidden_states, forward_batch)
|
||||||
|
|
||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
if (
|
if (
|
||||||
self.alt_stream is not None
|
self.alt_stream is not None
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
|
|||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
||||||
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
|
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
|
||||||
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
|
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
|
||||||
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
|
|||||||
sharded_weight_loader,
|
sharded_weight_loader,
|
||||||
)
|
)
|
||||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
|
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
|
||||||
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
|
from sglang.srt.utils import (
|
||||||
|
LazyValue,
|
||||||
|
add_prefix,
|
||||||
|
is_cuda,
|
||||||
|
is_npu,
|
||||||
|
make_layers,
|
||||||
|
set_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -849,13 +857,14 @@ class Qwen3NextModel(nn.Module):
|
|||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||||
layer_id=i,
|
hidden_states, residual = layer(
|
||||||
positions=positions,
|
layer_id=i,
|
||||||
hidden_states=hidden_states,
|
positions=positions,
|
||||||
residual=residual,
|
hidden_states=hidden_states,
|
||||||
forward_batch=forward_batch,
|
residual=residual,
|
||||||
)
|
forward_batch=forward_batch,
|
||||||
|
)
|
||||||
|
|
||||||
if not forward_batch.forward_mode.is_idle():
|
if not forward_batch.forward_mode.is_idle():
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@@ -901,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
|
|||||||
self.lm_head = self.lm_head.float()
|
self.lm_head = self.lm_head.float()
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
|
||||||
|
self._routed_experts_weights_of_layer = LazyValue(
|
||||||
|
lambda: {
|
||||||
|
layer_id: layer.mlp.get_moe_weights()
|
||||||
|
for layer_id, layer in enumerate(self.model.layers)
|
||||||
|
if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def routed_experts_weights_of_layer(self):
|
||||||
|
return self._routed_experts_weights_of_layer.value
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user