[7/N] MoE Refactor: the implementation of new framework (#9269)
This commit is contained in:
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||
from sglang.srt.layers.moe.utils import get_moe_runner_backend
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
FusedMoEMethodBase,
|
||||
@@ -59,8 +61,10 @@ if is_flashinfer_available():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -283,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
@@ -296,26 +300,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
intermediate_size_per_partition_after_pad = intermediate_size
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
if _is_sm100_supported:
|
||||
if self.use_flashinfer:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 256
|
||||
intermediate_size_per_partition, 256
|
||||
)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, 64
|
||||
intermediate_size_per_partition, 64
|
||||
)
|
||||
elif has_triton_kernels:
|
||||
# TODO: this is a hack to make
|
||||
# intermediate_size_per_partition_after_pad the same as the
|
||||
# per_rank_intermediate_size during weight loading
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size, mxfp4_block
|
||||
intermediate_size_per_partition, mxfp4_block
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition_after_pad
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
@@ -410,31 +414,35 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
assert (
|
||||
layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size
|
||||
and layer.w2_weight.shape[2] == self.intermediate_size // 2
|
||||
and layer.w2_weight.shape[2]
|
||||
== self.intermediate_size_per_partition // 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size
|
||||
== self.intermediate_size_per_partition // sf_block_size
|
||||
)
|
||||
assert (
|
||||
layer.w13_weight_bias.dim() == 2
|
||||
and layer.w13_weight_bias.shape[0] == self.num_experts
|
||||
and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight_bias.shape[1]
|
||||
== self.intermediate_size_per_partition * 2
|
||||
)
|
||||
assert (
|
||||
layer.w2_weight_bias.dim() == 2
|
||||
@@ -511,7 +519,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
torch.stack(gemm1_scales_mxfp4_shuffled)
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
2 * self.intermediate_size,
|
||||
2 * self.intermediate_size_per_partition,
|
||||
self.hidden_size // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
@@ -523,7 +531,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
.reshape(
|
||||
self.num_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size // sf_block_size,
|
||||
self.intermediate_size_per_partition // sf_block_size,
|
||||
)
|
||||
.view(torch.float8_e4m3fn)
|
||||
)
|
||||
@@ -613,16 +621,26 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
moe_runner_config = self.moe_runner_config
|
||||
|
||||
if self.use_flashinfer:
|
||||
# When bf16 mode is enabled, we don't need to quantize the input,
|
||||
# TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
|
||||
@@ -674,7 +692,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
top_k,
|
||||
None, # n_group # TODO: support n_group
|
||||
None, # topk_group # TODO: support topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
self.intermediate_size_per_partition, # padded to multiple of 256
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
None,
|
||||
@@ -682,14 +700,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
1, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
||||
return StandardCombineInput(hidden_states=trtllm_gen_output)
|
||||
|
||||
if self.use_triton_kernels:
|
||||
assert (
|
||||
layer.moe_ep_size == 1
|
||||
), "Expert parallel is not supported when using triton kernels"
|
||||
if self.with_bias:
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
output = self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
w1_pcg=self.w13_precision_config,
|
||||
@@ -701,25 +719,22 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
else:
|
||||
return self.triton_kernel_moe_forward(
|
||||
output = self.triton_kernel_moe_forward(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
else:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
moe_runner_config=moe_runner_config,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
quant_info = TritonMoeQuantInfo(
|
||||
w13_weight=layer.w13_weight,
|
||||
w2_weight=layer.w2_weight,
|
||||
w13_weight_bias=layer.w13_weight_bias,
|
||||
w2_weight_bias=layer.w2_weight_bias,
|
||||
)
|
||||
return self.runner.run(dispatch_output, quant_info)
|
||||
|
||||
|
||||
class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
@@ -798,7 +813,7 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return w, mx_scales
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w13, w13_mx_scales = self.mxfp4_quantize(layer.w13_weight.data)
|
||||
w2, w2_mx_scales = self.mxfp4_quantize(layer.w2_weight.data)
|
||||
|
||||
@@ -808,19 +823,27 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_mx_scales, requires_grad=False)
|
||||
|
||||
def create_moe_runner(
|
||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||
):
|
||||
self.moe_runner_config = moe_runner_config
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
dispatch_output: StandardDispatchOutput,
|
||||
) -> CombineInput:
|
||||
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
|
||||
|
||||
x = dispatch_output.hidden_states
|
||||
topk_output = dispatch_output.topk_output
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if _is_hip:
|
||||
topk_weights = topk_weights.to(
|
||||
torch.float32
|
||||
) # aiter's moe_sorting requires topk_weights to be FP32
|
||||
return fused_moe(
|
||||
output = fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@@ -831,8 +854,9 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=(
|
||||
ActivationType.Silu
|
||||
if moe_runner_config.activation == "silu"
|
||||
if self.moe_runner_config.activation == "silu"
|
||||
else ActivationType.Gelu
|
||||
),
|
||||
doweight_stage1=False,
|
||||
)
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
Reference in New Issue
Block a user