Expert Parallelism for GPT-OSS (#8944)
This commit is contained in:
@@ -8,6 +8,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
is_triton_kernels_available,
|
||||
log_info_on_rank0,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
has_triton_kernels = is_triton_kernels_available()
|
||||
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
use_flashinfer = global_server_args_dict.get(
|
||||
"enable_flashinfer_mxfp4_moe", False
|
||||
)
|
||||
return Mxfp4MoEMethod(
|
||||
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
|
||||
)
|
||||
return Mxfp4MoEMethod(prefix)
|
||||
else:
|
||||
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_triton_kernels: bool = True,
|
||||
with_bias: bool = True,
|
||||
use_flashinfer: bool = False,
|
||||
prefix: str,
|
||||
):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
self.with_bias = with_bias
|
||||
self.use_flashinfer = use_flashinfer
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
self.with_bias = False
|
||||
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# print(f"hi {self=} create_weights {layer=}")
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
self.with_bias = with_bias
|
||||
mxfp4_block = 32
|
||||
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_weight_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype,
|
||||
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_bias = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
|
||||
torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
||||
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
return
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
if self.use_triton_kernels:
|
||||
|
||||
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
||||
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
||||
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
||||
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
||||
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
||||
|
||||
num_warps = 8
|
||||
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
||||
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
num_warps = 8
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||
)
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||
)
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
else:
|
||||
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
|
||||
|
||||
w13_weight = upcast_from_mxfp(
|
||||
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
w2_weight = upcast_from_mxfp(
|
||||
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w2_weight_scale
|
||||
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
self.num_experts,
|
||||
layer.num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
0, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
1, # routing_method_type, renormalize
|
||||
@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
return trtllm_gen_output
|
||||
|
||||
if self.use_triton_kernels:
|
||||
assert (
|
||||
layer.moe_ep_size == 1
|
||||
), "Expert parallel is not supported when using triton kernels"
|
||||
if self.with_bias:
|
||||
# TODO why we do not put weights on layer?
|
||||
assert layer.w13_weight is None
|
||||
assert layer.w2_weight is None
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
topk_output=topk_output,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user