Optimize Qwen3-moe model by using flashinfer fused allreduce (#9973)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -42,9 +42,15 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+from sglang.srt.utils import (
+    is_cuda,
+    is_flashinfer_available,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -484,11 +490,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
 # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
 # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
 if (
-    _is_sm100_supported
+    (_is_sm100_supported or _is_sm90_supported)
     and _is_flashinfer_available
     and hasattr(layernorm, "forward_with_allreduce_fusion")
     and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-    and hidden_states.shape[0] <= 2048
+    and hidden_states.shape[0] <= 4096
 ):
     hidden_states, residual = layernorm.forward_with_allreduce_fusion(
         hidden_states, residual
Reference in New Issue
Block a user