[7/N] MoE Refactor: the implementation of the new framework (#9269)
@@ -10,10 +10,14 @@ from torch.nn.parameter import Parameter
 from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
 from sglang.srt.layers.moe import (
+    MoeRunner,
+    MoeRunnerBackend,
+    MoeRunnerConfig,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -39,8 +43,10 @@ from sglang.srt.utils import is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
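The import changes above trace the new control flow of the refactor: quantization methods no longer invoke a fused kernel directly, but hand a dispatch output plus backend-specific quant info to a MoeRunner, which returns a combine input. A minimal sketch of that contract with stand-in types; the real StandardDispatchOutput, CombineInput, and MoeRunner live under sglang.srt.layers.moe, and only the field names hidden_states and topk_output are taken from this diff:

from dataclasses import dataclass
from typing import Any


@dataclass
class DispatchOutputSketch:  # stand-in for StandardDispatchOutput
    hidden_states: Any       # activations previously passed to apply() as `x`
    topk_output: Any         # routing result previously passed as a separate argument


@dataclass
class CombineInputSketch:    # stand-in for StandardCombineInput
    hidden_states: Any


class RunnerSketch:          # stand-in for MoeRunner
    def run(self, dispatch_output: DispatchOutputSketch, quant_info: Any) -> CombineInputSketch:
        # A real backend (e.g. MoeRunnerBackend.TRITON) would run the fused
        # expert kernels here; this sketch just forwards the activations.
        return CombineInputSketch(hidden_states=dispatch_output.hidden_states)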
@@ -322,7 +328,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -338,7 +344,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w13_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
@@ -348,7 +357,10 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = ModelWeightParameter(
             data=torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=weight_dtype,
             ),
             input_dim=2,
             output_dim=1,
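The two shape changes above are the tensor-parallel fix this hunk series is about: the weights allocated on each rank must use the per-rank shard of the FFN width, not the full width. A small self-contained arithmetic sketch, with concrete sizes invented purely for illustration:

# Illustrative only: how the per-partition size relates to the full
# intermediate size under tensor parallelism.
intermediate_size = 14336   # full FFN width of one expert (example value)
tp_size = 4                 # tensor-parallel world size (example value)
assert intermediate_size % tp_size == 0
intermediate_size_per_partition = intermediate_size // tp_size  # 3584

num_experts, hidden_size = 8, 4096
# Shapes allocated above, per rank:
w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size)
w2_shape = (num_experts, hidden_size, intermediate_size_per_partition)
print(w13_shape, w2_shape)  # (8, 7168, 4096) (8, 4096, 3584)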
@@ -414,28 +426,28 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
 
         # Requantize each expert's weights using the combined scale
-        # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
-        # where the first intermediate_size rows are w1, the next are w3
-        intermediate_size = layer.w13_weight.shape[1] // 2
+        # w13_weight has shape (num_experts, 2 * intermediate_size_per_partition, hidden_size)
+        # where the first intermediate_size_per_partition rows are w1, the next are w3
+        intermediate_size_per_partition = layer.w13_weight.shape[1] // 2
         for expert_id in range(layer.w13_weight.shape[0]):
             start = 0
             for shard_id in range(2):  # w1 and w3
                 # Dequantize using the original scale for this shard
                 dq_weight = per_tensor_dequantize(
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     layer.w13_weight_scale[expert_id][shard_id],
                 )
                 # Requantize using the combined max scale
                 (
                     layer.w13_weight[expert_id][
-                        start : start + intermediate_size, :
+                        start : start + intermediate_size_per_partition, :
                     ],
                     _,
                 ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                start += intermediate_size
+                start += intermediate_size_per_partition
 
         # Update the scale parameter to be per-expert instead of per-shard
         layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
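The loop above merges the two per-shard scales of w1 and w3 into one per-expert scale by dequantizing each shard with its original scale and requantizing with the shared maximum. A self-contained sketch of that round trip, assuming the simple per-tensor scheme where per_tensor_dequantize(q, s) computes q.float() * s and scaled_fp8_quant(w, s) clamps w / s into FP8 range; the _sketch helpers are illustrations, not sglang's implementations, and FP8 dtypes require torch >= 2.1:

import torch

FP8_MAX = 448.0  # finfo max of torch.float8_e4m3fn


def per_tensor_dequantize_sketch(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale


def scaled_fp8_quant_sketch(w: torch.Tensor, scale: torch.Tensor):
    q = (w / scale).clamp(-FP8_MAX, FP8_MAX)
    return q.to(torch.float8_e4m3fn), scale


# w1 and w3 were quantized with separate scales s1, s3; to fuse them into one
# w13 tensor, both shards are re-expressed under the shared scale max(s1, s3):
w1, w3 = torch.randn(16, 32), torch.randn(16, 32)
s1 = w1.abs().max() / FP8_MAX
s3 = w3.abs().max() / FP8_MAX
q1, _ = scaled_fp8_quant_sketch(w1, s1)
q3, _ = scaled_fp8_quant_sketch(w3, s3)
s13 = torch.max(s1, s3)  # the combined per-expert scale
rq1, _ = scaled_fp8_quant_sketch(per_tensor_dequantize_sketch(q1, s1), s13)
rq3, _ = scaled_fp8_quant_sketch(per_tensor_dequantize_sketch(q3, s3), s13)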
@@ -457,29 +469,31 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale.max(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_fp8_w8a8=True,
-            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
-            w1_scale=layer.w13_weight_scale,
+            per_channel_quant=False,
+            w13_scale=layer.w13_weight_scale,
             w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
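Beyond the signature change, the hunk above also renames keywords so that the Triton quant-info fields follow the fused-tensor naming. As read directly off this hunk (an illustrative mapping, not an API):

# How the old fused_experts keywords map onto TritonMoeQuantInfo fields;
# None marks arguments that moved elsewhere rather than being renamed.
KWARG_MAP = {
    "w1_scale": "w13_scale",    # weight scale named after the fused w13 tensor
    "a1_scale": "a13_scale",    # activation scale follows the same convention
    "topk_output": None,        # now travels inside dispatch_output
    "moe_runner_config": None,  # now cached once by create_moe_runner()
}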
@@ -1278,21 +1292,32 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         # FlashInfer CUTLASS kernel assumes [Up, Gate] Proj as W13
         return self.enable_flashinfer_cutlass_moe
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: FusedMoE,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."
 
+        moe_runner_config = self.moe_runner_config
+
         # Check if this is a FlashInferFP4MoE layer that should handle its own forward
         if hasattr(layer, "gemm1_weights_fp4_shuffled"):
             # This layer was processed with flashinfer TRTLLM - delegate to its own forward
-            return layer.forward(x, topk_output)
+            return StandardCombineInput(hidden_states=layer.forward(x, topk_output))
 
         if self.enable_flashinfer_cutlass_moe:
             assert (
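Two conventions are worth noting in the hunk above. The TRTLLM delegation check is plain duck typing: a layer post-processed for FlashInfer TRTLLM grows a gemm1_weights_fp4_shuffled attribute, and apply() routes to the layer's own forward in that case. And every exit path must now return a CombineInput rather than a bare tensor. A minimal stand-alone illustration of the duck-typing dispatch, with toy classes that are not sglang's:

class PlainLayer:
    pass


class TrtllmLayer:
    gemm1_weights_fp4_shuffled = object()  # marker attribute set during weight processing

    def forward(self, x, topk_output):
        return x  # the layer owns its entire fused forward path


def dispatch_style(layer):
    # Mirrors the hasattr() check in apply(): delegate if the marker exists.
    return "delegate" if hasattr(layer, "gemm1_weights_fp4_shuffled") else "inline"


assert dispatch_style(PlainLayer()) == "inline"
assert dispatch_style(TrtllmLayer()) == "delegate"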
@@ -1345,13 +1370,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 tp_rank=layer.moe_tp_rank,
                 tune_max_num_tokens=next_power_of_2(x.shape[0]),
             )[0]
             # Scale by routed_scaling_factor is fused into select_experts.
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
-            return output
+            return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
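The reduce_scatterv branch is presumably the inverse of the FP4 all-gather on the way in: with the allgather path enabled, each rank ran the MoE over the global token batch, so under the usual reduce-scatter semantics the collective combines the partial results across ranks while handing each rank back only its own rows. A pure-Python sketch of the row bookkeeping, with sizes invented and no collectives involved:

# Which rows of the global output belong to which DP rank. In the real code,
# get_dp_global_num_tokens() supplies these sizes and get_local_dp_buffer()
# supplies the preallocated local slice that reduce_scatterv writes into.
global_num_tokens = [3, 5, 2]  # tokens contributed by DP ranks 0..2 (example)
for rank in range(len(global_num_tokens)):
    offset = sum(global_num_tokens[:rank])
    local = global_num_tokens[rank]
    print(f"rank {rank} keeps global rows [{offset}, {offset + local})")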
@@ -1372,4 +1396,5 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
         ).to(x.dtype)
         # Scale by routed_scaling_factor is fused into select_experts.
-        return output
+
+        return StandardCombineInput(hidden_states=output)