Expert Parallelism for GPT-OSS (#8944)

Cheng Wan authored 2025-08-08 00:46:42 -07:00; committed by GitHub
parent 444013585d
commit 1d24db8348
8 changed files with 269 additions and 119 deletions


@@ -8,6 +8,7 @@ import logging
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
+import triton.language as tl
 from torch.nn.parameter import Parameter
 
 from sglang.srt.layers.quantization.base_config import (
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_triton_kernels_available,
     log_info_on_rank0,
     next_power_of_2,
     round_up,
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
 )
 
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+has_triton_kernels = is_triton_kernels_available()
 
 if is_flashinfer_available():
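
The new helper centralizes the probe that the removed line performed inline. A minimal sketch of what such a predicate plausibly wraps — the real one lives in `sglang.srt.utils`, so treat this as an assumption rather than the actual implementation:

```python
import importlib.util

def is_triton_kernels_available() -> bool:
    # True when the optional `triton_kernels` package is importable,
    # mirroring the inline find_spec check this commit removes.
    return importlib.util.find_spec("triton_kernels") is not None
```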
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            use_flashinfer = global_server_args_dict.get(
-                "enable_flashinfer_mxfp4_moe", False
-            )
-            return Mxfp4MoEMethod(
-                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
-            )
+            return Mxfp4MoEMethod(prefix)
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
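
The construction site shrinks to a single argument: kernel and bias flags are no longer threaded through `get_quant_method`. A hypothetical call illustrating the new signature (the prefix value is made up):

```python
# Only the weight-name prefix crosses the boundary now; all feature flags
# are resolved inside Mxfp4MoEMethod.__init__ from the server args.
method = Mxfp4MoEMethod("model.layers.0.mlp.experts")  # prefix is illustrative
```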
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
-        use_triton_kernels: bool = True,
-        with_bias: bool = True,
-        use_flashinfer: bool = False,
+        prefix: str,
     ):
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
         super().__init__()
         self.topk_indices_dtype = None
-        self.use_triton_kernels = use_triton_kernels
-        self.with_bias = with_bias
-        self.use_flashinfer = use_flashinfer
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+        self.with_bias = False
+        self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
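
Note the semantic shift: the old call site used `.get("enable_flashinfer_mxfp4_moe", False)`, while the constructor now indexes the dict directly, so the keys are assumed to always be present. A toy sketch of the difference:

```python
args = {"enable_triton_kernel_moe": False}  # stand-in for global_server_args_dict

old_style = args.get("enable_flashinfer_mxfp4_moe", False)  # missing key -> False
new_style = args["enable_triton_kernel_moe"]                # missing key -> KeyError
```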
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
+        with_bias: bool = False,
         **extra_weight_attrs,
     ):
-        # print(f"hi {self=} create_weights {layer=}")
         self.num_experts = num_experts
         weight_dtype = torch.uint8
         scale_dtype = torch.uint8
+        self.with_bias = with_bias
         mxfp4_block = 32
 
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // 2,
                 dtype=weight_dtype,
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 hidden_size // mxfp4_block,
                 dtype=scale_dtype,
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_bias = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 2 * intermediate_size_per_partition_after_pad,
                 dtype=torch.bfloat16,
             ),
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // 2,
                 dtype=weight_dtype,
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
+                layer.num_local_experts,
                 hidden_size,
                 intermediate_size_per_partition_after_pad // mxfp4_block,
                 dtype=scale_dtype,
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         w2_weight_bias = torch.nn.Parameter(
-            torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
+            torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
         )
         layer.register_parameter("w2_weight_bias", w2_weight_bias)
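
Every allocation above switches its leading dimension from the global expert count to `layer.num_local_experts`, which is the heart of the expert-parallel change: each EP rank now materializes only the expert slice it owns. A back-of-the-envelope sketch, ignoring TP partitioning and padding; the EP size is illustrative, and the model dimensions are gpt-oss-120b's published ones:

```python
num_experts = 128       # global expert count (gpt-oss-120b)
moe_ep_size = 4         # expert-parallel world size (illustrative)
assert num_experts % moe_ep_size == 0
num_local_experts = num_experts // moe_ep_size  # 32 experts held per rank

# mxfp4 packs two 4-bit values per uint8, hence the hidden_size // 2 above.
hidden_size, intermediate_size = 2880, 2880
w13_shape = (num_local_experts, 2 * intermediate_size, hidden_size // 2)
print(w13_shape)  # (32, 5760, 1440) per rank instead of (128, 5760, 1440)
```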
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
             return
 
-        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
-
-        w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
-        w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
-
-        layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
-        layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
-
-        num_warps = 8
-
-        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
-            layer.w13_weight, layer.w13_weight_scale, num_warps
-        )
-        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
-            layer.w2_weight, layer.w2_weight_scale, num_warps
-        )
-
-        self.w13_precision_config = PrecisionConfig(
-            weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
-        )
-        self.w2_precision_config = PrecisionConfig(
-            weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
-        )
-
-        self.w13_weight_triton_tensor = w13_weight
-        self.w2_weight_triton_tensor = w2_weight
-
-        # need to delete the original weights to save memory on single GPU
-        del layer.w13_weight
-        del layer.w2_weight
-        layer.w13_weight = None
-        layer.w2_weight = None
+        if self.use_triton_kernels:
+            from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+            w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
+            w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
+
+            layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
+            layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
+
+            num_warps = 8
+
+            w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+                layer.w13_weight, layer.w13_weight_scale, num_warps
+            )
+            w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+                layer.w2_weight, layer.w2_weight_scale, num_warps
+            )
+
+            self.w13_precision_config = PrecisionConfig(
+                weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
+            )
+            self.w2_precision_config = PrecisionConfig(
+                weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
+            )
+
+            self.w13_weight_triton_tensor = w13_weight
+            self.w2_weight_triton_tensor = w2_weight
+
+            # need to delete the original weights to save memory on single GPU
+            del layer.w13_weight
+            del layer.w2_weight
+            layer.w13_weight = None
+            layer.w2_weight = None
+        else:
+            from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
+
+            w13_weight = upcast_from_mxfp(
+                layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+            w2_weight = upcast_from_mxfp(
+                layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
+            )
+            del layer.w13_weight
+            del layer.w2_weight
+            del layer.w13_weight_scale
+            del layer.w2_weight_scale
+            layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
+
         torch.cuda.empty_cache()
 
     def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
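
The new `else` branch dequantizes the mxfp4 weights to bf16 so the plain fused-MoE path can consume them. As a reference for what `upcast_from_mxfp` has to do numerically — a sketch under the OCP MX convention of 32-element blocks sharing an e8m0 scale (bias 127), ignoring the packed-uint8 layout:

```python
import torch

def upcast_block(fp4_vals: torch.Tensor, e8m0_scale: torch.Tensor) -> torch.Tensor:
    # fp4_vals: (..., 32) already-decoded e2m1 values; e8m0_scale: (...,) biased
    # exponent. The shared scale is applied as a power of two per block.
    scale = torch.exp2(e8m0_scale.float() - 127.0)
    return (fp4_vals.float() * scale.unsqueeze(-1)).to(torch.bfloat16)

vals = torch.tensor([[0.5, 1.0, 1.5, 6.0] * 8])   # e2m1-representable values
print(upcast_block(vals, torch.tensor([128.0])))  # exponent 128 -> scale 2**1
```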
@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             None, # output1_scale_scalar
             None, # output1_scale_gate_scalar
             None, # output2_scale_scalar
-            self.num_experts,
+            layer.num_experts,
             top_k,
             None, # n_group
             None, # topk_group
             self.intermediate_size, # padded to multiple of 256
-            0, # local_expert_offset
-            self.num_experts, # local num experts
+            layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
+            layer.num_local_experts, # local num experts
             None,
             self._get_tile_tokens_dim(x, top_k),
             1, # routing_method_type, renormalize
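
The routing call now hands the flashinfer kernel a rank-dependent expert window instead of the full `[0, num_experts)` range. A worked example of the offset arithmetic, with illustrative sizes:

```python
num_experts, moe_ep_size = 128, 4
num_local_experts = num_experts // moe_ep_size
for moe_ep_rank in range(moe_ep_size):
    local_expert_offset = moe_ep_rank * num_local_experts
    print(f"rank {moe_ep_rank}: experts "
          f"[{local_expert_offset}, {local_expert_offset + num_local_experts})")
# rank 0: experts [0, 32) ... rank 3: experts [96, 128)
```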
@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             return trtllm_gen_output
 
         if self.use_triton_kernels:
-            assert (
-                layer.moe_ep_size == 1
-            ), "Expert parallel is not supported when using triton kernels"
             if self.with_bias:
+                # TODO why we do not put weights on layer?
+                assert layer.w13_weight is None
+                assert layer.w2_weight is None
                 return self.triton_kernel_moe_with_bias_forward(
                     hidden_states=x,
                     w1=self.w13_weight_triton_tensor,
@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                     topk_output=topk_output,
                 )
         else:
-            raise NotImplementedError()
+            from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+                b1=layer.w13_weight_bias,
+                b2=layer.w2_weight_bias,
+                inplace=inplace,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                no_combine=no_combine,
+                routed_scaling_factor=routed_scaling_factor,
+                activation_alpha=activation_alpha,
+                swiglu_limit=swiglu_limit,
+            )
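
The fallback threads `activation_alpha` and `swiglu_limit` through to the fused kernel. For reference, this is a sketch of the clamped SwiGLU those knobs parameterize; the constants and the `(up + 1)` shift follow the published gpt-oss reference implementation, not necessarily this kernel's exact code:

```python
import torch

def gpt_oss_swiglu(gate: torch.Tensor, up: torch.Tensor,
                   alpha: float = 1.702, limit: float = 7.0) -> torch.Tensor:
    gate = gate.clamp(max=limit)          # gate is clamped from above only
    up = up.clamp(min=-limit, max=limit)  # up is clamped symmetrically
    return gate * torch.sigmoid(alpha * gate) * (up + 1)
```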