[2/2] Fuse routed scaling factor into select_experts (#8690)
This commit is contained in:
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
|
||||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
@@ -923,6 +924,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
for shard_id in ["w1", "w2", "w3"]
|
for shard_id in ["w1", "w2", "w3"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def should_fuse_routed_scaling_factor_in_topk(self):
    """Tell whether the routed scaling factor should be fused into top-k.

    Returns a truthy value when ``self.quant_method`` is a
    ``ModelOptNvFp4FusedMoEMethod``, or when it is an ``Fp8MoEMethod`` whose
    ``use_cutlass_fused_experts_fp8`` flag is set; otherwise a falsy value.
    """
    method = self.quant_method
    # The NVFP4 ModelOpt method always qualifies.
    if isinstance(method, ModelOptNvFp4FusedMoEMethod):
        return True
    # FP8 qualifies only when its CUTLASS fused-experts flag is set.
    return isinstance(method, Fp8MoEMethod) and method.use_cutlass_fused_experts_fp8
|
||||||
|
|
||||||
|
|
||||||
class FlashInferFusedMoE(FusedMoE):
|
class FlashInferFusedMoE(FusedMoE):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|||||||
@@ -197,6 +197,7 @@ class TopK(CustomOp):
|
|||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
# NOTE: scoring_func is not used for now, but we keep it for future use
|
# NOTE: scoring_func is not used for now, but we keep it for future use
|
||||||
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
# see https://github.com/sgl-project/sglang/pull/4505 for more details
|
||||||
@@ -215,6 +216,7 @@ class TopK(CustomOp):
|
|||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
correction_bias=correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
|
||||||
@@ -433,6 +435,7 @@ def grouped_topk_gpu(
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
@@ -480,6 +483,8 @@ def grouped_topk_gpu(
|
|||||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||||
)
|
)
|
||||||
topk_weights = topk_weights / topk_weights_sum
|
topk_weights = topk_weights / topk_weights_sum
|
||||||
|
if apply_routed_scaling_factor_on_output:
|
||||||
|
topk_weights *= routed_scaling_factor
|
||||||
|
|
||||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||||
@@ -528,6 +533,7 @@ def biased_grouped_topk_impl(
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
@@ -579,6 +585,8 @@ def biased_grouped_topk_impl(
|
|||||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||||
)
|
)
|
||||||
topk_weights = topk_weights / topk_weights_sum
|
topk_weights = topk_weights / topk_weights_sum
|
||||||
|
if apply_routed_scaling_factor_on_output:
|
||||||
|
topk_weights *= routed_scaling_factor
|
||||||
|
|
||||||
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||||
@@ -621,6 +629,7 @@ def biased_grouped_topk_gpu(
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
assert (
|
assert (
|
||||||
routed_scaling_factor is not None
|
routed_scaling_factor is not None
|
||||||
@@ -640,6 +649,7 @@ def biased_grouped_topk_gpu(
|
|||||||
topk,
|
topk,
|
||||||
num_fused_shared_experts,
|
num_fused_shared_experts,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
|
apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
# TODO merge into kernel
|
# TODO merge into kernel
|
||||||
if (expert_location_dispatch_info is not None) or (
|
if (expert_location_dispatch_info is not None) or (
|
||||||
@@ -650,6 +660,7 @@ def biased_grouped_topk_gpu(
|
|||||||
)
|
)
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
elif _use_aiter:
|
elif _use_aiter:
|
||||||
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||||
token = gating_output.shape[0]
|
token = gating_output.shape[0]
|
||||||
device = gating_output.device
|
device = gating_output.device
|
||||||
assert (
|
assert (
|
||||||
@@ -681,6 +692,7 @@ def biased_grouped_topk_gpu(
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -743,6 +755,9 @@ def select_experts(
|
|||||||
correction_bias = topk_config.correction_bias
|
correction_bias = topk_config.correction_bias
|
||||||
torch_native = topk_config.torch_native
|
torch_native = topk_config.torch_native
|
||||||
routed_scaling_factor = topk_config.routed_scaling_factor
|
routed_scaling_factor = topk_config.routed_scaling_factor
|
||||||
|
apply_routed_scaling_factor_on_output = (
|
||||||
|
topk_config.apply_routed_scaling_factor_on_output
|
||||||
|
)
|
||||||
|
|
||||||
router_logits, correction_bias = (
|
router_logits, correction_bias = (
|
||||||
expert_location_dispatch.transform_select_experts_inputs(
|
expert_location_dispatch.transform_select_experts_inputs(
|
||||||
@@ -768,6 +783,7 @@ def select_experts(
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
topk_weights, topk_ids = biased_grouped_topk(
|
topk_weights, topk_ids = biased_grouped_topk(
|
||||||
@@ -782,12 +798,14 @@ def select_experts(
|
|||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
num_token_non_padded=num_token_non_padded,
|
num_token_non_padded=num_token_non_padded,
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
elif torch_native and custom_routing_function is None:
|
elif torch_native and custom_routing_function is None:
|
||||||
assert (
|
assert (
|
||||||
num_token_non_padded is None
|
num_token_non_padded is None
|
||||||
), "num_token_non_padded is not yet supported in fused_topk_native"
|
), "num_token_non_padded is not yet supported in fused_topk_native"
|
||||||
assert expert_location_dispatch_info is None
|
assert expert_location_dispatch_info is None
|
||||||
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||||
topk_weights, topk_ids = fused_topk_native(
|
topk_weights, topk_ids = fused_topk_native(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
@@ -795,6 +813,7 @@ def select_experts(
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
)
|
)
|
||||||
elif custom_routing_function is None:
|
elif custom_routing_function is None:
|
||||||
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||||
# Qwen3MOE uses fused_topk
|
# Qwen3MOE uses fused_topk
|
||||||
topk_weights, topk_ids = fused_topk(
|
topk_weights, topk_ids = fused_topk(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -809,6 +828,7 @@ def select_experts(
|
|||||||
num_token_non_padded is None
|
num_token_non_padded is None
|
||||||
), "num_token_non_padded is not yet supported in custom_routing_function"
|
), "num_token_non_padded is not yet supported in custom_routing_function"
|
||||||
assert expert_location_dispatch_info is None
|
assert expert_location_dispatch_info is None
|
||||||
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
topk_weights, topk_ids = custom_routing_function(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
gating_output=router_logits,
|
gating_output=router_logits,
|
||||||
|
|||||||
@@ -514,6 +514,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.block_quant = self.quant_config.weight_block_size is not None
|
self.block_quant = self.quant_config.weight_block_size is not None
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
|
self.use_cutlass_fused_experts_fp8 = (
|
||||||
|
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
||||||
|
and self.cutlass_fp8_supported
|
||||||
|
and self.block_quant
|
||||||
|
and (is_sm100_supported() or is_sm90_supported())
|
||||||
|
)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
@@ -1021,12 +1027,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
if ret is not None:
|
if ret is not None:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
if (
|
if self.use_cutlass_fused_experts_fp8:
|
||||||
get_bool_env_var("SGLANG_CUTLASS_MOE")
|
|
||||||
and self.cutlass_fp8_supported
|
|
||||||
and self.block_quant
|
|
||||||
and (is_sm100_supported() or is_sm90_supported())
|
|
||||||
):
|
|
||||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
@@ -1053,9 +1054,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.problem_sizes2,
|
self.problem_sizes2,
|
||||||
use_fp8_blockscale=True,
|
use_fp8_blockscale=True,
|
||||||
)
|
)
|
||||||
# TODO: Fuse into select_experts
|
# Scale by routed_scaling_factor is fused into select_experts.
|
||||||
if moe_runner_config.routed_scaling_factor is not None:
|
|
||||||
output *= moe_runner_config.routed_scaling_factor
|
|
||||||
return output
|
return output
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
|
|||||||
@@ -1305,8 +1305,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
tp_rank=layer.moe_tp_rank,
|
tp_rank=layer.moe_tp_rank,
|
||||||
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
tune_max_num_tokens=next_power_of_2(x.shape[0]),
|
||||||
)[0]
|
)[0]
|
||||||
if moe_runner_config.routed_scaling_factor is not None:
|
# Scale by routed_scaling_factor is fused into select_experts.
|
||||||
output *= moe_runner_config.routed_scaling_factor
|
|
||||||
if should_use_flashinfer_cutlass_moe_fp4_allgather():
|
if should_use_flashinfer_cutlass_moe_fp4_allgather():
|
||||||
output, global_output = get_local_dp_buffer(), output
|
output, global_output = get_local_dp_buffer(), output
|
||||||
get_tp_group().reduce_scatterv(
|
get_tp_group().reduce_scatterv(
|
||||||
@@ -1332,6 +1331,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
|||||||
params=layer.cutlass_moe_params,
|
params=layer.cutlass_moe_params,
|
||||||
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||||
).to(x.dtype)
|
).to(x.dtype)
|
||||||
if moe_runner_config.routed_scaling_factor is not None:
|
# Scale by routed_scaling_factor is fused into select_experts.
|
||||||
output *= moe_runner_config.routed_scaling_factor
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -319,17 +319,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
|
||||||
)
|
)
|
||||||
|
|
||||||
self.topk = TopK(
|
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
|
||||||
renormalize=config.norm_topk_prob,
|
|
||||||
use_grouped_topk=True,
|
|
||||||
num_expert_group=config.n_group,
|
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
|
||||||
topk_group=config.topk_group,
|
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.experts = get_moe_impl_class()(
|
self.experts = get_moe_impl_class()(
|
||||||
num_experts=config.n_routed_experts
|
num_experts=config.n_routed_experts
|
||||||
+ self.num_fused_shared_experts
|
+ self.num_fused_shared_experts
|
||||||
@@ -344,6 +333,18 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.topk = TopK(
|
||||||
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
|
renormalize=config.norm_topk_prob,
|
||||||
|
use_grouped_topk=True,
|
||||||
|
num_expert_group=config.n_group,
|
||||||
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
|
topk_group=config.topk_group,
|
||||||
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
||||||
|
)
|
||||||
|
|
||||||
self.shared_experts_is_int8 = False
|
self.shared_experts_is_int8 = False
|
||||||
self.shared_experts_is_fp8 = False
|
self.shared_experts_is_fp8 = False
|
||||||
self.shared_experts_weight_block_size = None
|
self.shared_experts_weight_block_size = None
|
||||||
|
|||||||
@@ -19,7 +19,10 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
|
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
|
||||||
def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
|
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
|
||||||
|
def test_moe_fused_gate_combined(
|
||||||
|
seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
|
||||||
|
):
|
||||||
num_experts, num_expert_group, topk_group, topk = params
|
num_experts, num_expert_group, topk_group, topk = params
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
@@ -37,6 +40,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
|
|||||||
topk=topk,
|
topk=topk,
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
routed_scaling_factor=2.5,
|
routed_scaling_factor=2.5,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
ref_output, ref_indices = biased_grouped_topk(
|
ref_output, ref_indices = biased_grouped_topk(
|
||||||
scores,
|
scores,
|
||||||
@@ -48,6 +52,7 @@ def test_moe_fused_gate_combined(seq_length, params, num_fused_shared_experts):
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_fused_shared_experts=num_fused_shared_experts,
|
num_fused_shared_experts=num_fused_shared_experts,
|
||||||
routed_scaling_factor=2.5,
|
routed_scaling_factor=2.5,
|
||||||
|
apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
|
# When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
|
||||||
|
|||||||
Reference in New Issue
Block a user