Optimize Qwen3-moe model by using flashinfer fused allreduce (#9973)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
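Background for the change: with tensor parallelism, each layer's MLP output is all-reduced across TP ranks and then fed through a residual add and RMSNorm as separate kernels. This commit wires the Qwen3-MoE layers into flashinfer's fused allreduce path (layernorm.forward_with_allreduce_fusion), extends the hardware gate to SM90, and raises the token cap. A minimal sketch of the unfused sequence that the fused kernel replaces, written with plain torch.distributed for illustration (the rms_norm math here is illustrative, not sglang's implementation):

import torch
import torch.distributed as dist


def unfused_allreduce_add_rmsnorm(
    hidden_states: torch.Tensor,  # this TP rank's partial MLP output
    residual: torch.Tensor,       # residual stream entering the layer
    norm_weight: torch.Tensor,    # RMSNorm gain of the next layernorm
    eps: float = 1e-6,
):
    # 1) sum partial MLP outputs across tensor-parallel ranks
    dist.all_reduce(hidden_states)
    # 2) add the residual stream
    residual = residual + hidden_states
    # 3) RMSNorm the result for the next sub-layer
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype)
    return normed * norm_weight, residual

The fused path performs steps 1-3 in a single kernel launch, which is why the eligibility check in the diff depends on the GPU architecture, the flashinfer install, and the token count.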
@@ -42,9 +42,15 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+from sglang.srt.utils import (
+    is_cuda,
+    is_flashinfer_available,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -484,11 +490,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 2048
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
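For reference, the eligibility condition above restated as a standalone predicate; a sketch only, where the function name and parameters are illustrative stand-ins for the module globals and server args used in the hunks above:

def can_fuse_allreduce_layernorm(
    layernorm,
    hidden_states,
    server_args: dict,
    sm90_supported: bool,
    sm100_supported: bool,
    flashinfer_available: bool,
) -> bool:
    # SM90 (Hopper) is now accepted alongside SM100 (Blackwell), and the
    # per-call token cap on hidden_states rows is raised from 2048 to 4096.
    return (
        (sm100_supported or sm90_supported)
        and flashinfer_available
        and hasattr(layernorm, "forward_with_allreduce_fusion")
        and server_args["enable_flashinfer_allreduce_fusion"]
        and hidden_states.shape[0] <= 4096
    )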
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
 
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
 
         return hidden_states, residual
 
     def op_comm_prepare_attn(
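Taken together, the model-side changes thread a single should_allreduce_fusion flag from the decoder layer down to the MoE block, so the layer's own all-reduce can be skipped and deferred into the next layer's fused layernorm. A condensed sketch of the new control flow (method bodies and the communicator internals are elided; names come from the hunks above):

def decoder_layer_mlp_tail(self, hidden_states, residual, forward_batch):
    # Ask the communicator whether this layer's MLP all-reduce may be fused
    # into the next layer's RMSNorm; the is_last_layer=... argument added
    # above presumably disables this for the final layer.
    should_allreduce_fusion = (
        self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
            forward_batch
        )
    )
    use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
        forward_batch
    )

    # When fusion is on, Qwen3MoeSparseMoeBlock.forward_normal and Qwen2MoeMLP
    # skip their own all-reduce (skip_all_reduce=True in down_proj, and the
    # tensor_model_parallel_all_reduce branch is bypassed).
    hidden_states = self.mlp(
        hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
    )

    if should_allreduce_fusion:
        # Tag the tensor; the communicator presumably checks this marker and
        # runs layernorm.forward_with_allreduce_fusion for the next layer
        # instead of a separate all-reduce + add + norm.
        hidden_states._sglang_needs_allreduce_fusion = True
    else:
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )
    return hidden_states, residual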