Expert Parallelism for GPT-OSS (#8944)
This commit is contained in:
@@ -76,6 +76,9 @@ class EPMoE(FusedMoE):
|
||||
prefix: str = "",
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
with_bias: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
num_experts=num_experts,
|
||||
@@ -91,6 +94,9 @@ class EPMoE(FusedMoE):
|
||||
activation=activation,
|
||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
with_bias=with_bias,
|
||||
)
|
||||
|
||||
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
||||
|
||||
@@ -319,6 +319,7 @@ def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
bias_ptr,
|
||||
c_ptr,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
@@ -340,6 +341,8 @@ def fused_moe_kernel(
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_bias_e,
|
||||
stride_bias_n,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_asm,
|
||||
@@ -449,6 +452,10 @@ def fused_moe_kernel(
|
||||
+ off_experts * stride_be
|
||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
)
|
||||
if bias_ptr is not None:
|
||||
bias = tl.load(
|
||||
bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
|
||||
)
|
||||
if use_int8_w8a16:
|
||||
b_scale_ptrs = (
|
||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||
@@ -526,18 +533,20 @@ def fused_moe_kernel(
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if use_int8_w8a16:
|
||||
accumulator *= b_scale
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k == 0 or group_n == 0:
|
||||
accumulator *= a_scale * b_scale
|
||||
|
||||
if bias_ptr is not None:
|
||||
accumulator += bias
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
if use_int8_w8a16:
|
||||
accumulator = (accumulator * b_scale).to(compute_type)
|
||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||
if group_k > 0 and group_n > 0:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
else:
|
||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
||||
else:
|
||||
accumulator = accumulator.to(compute_type)
|
||||
accumulator *= moe_weight[:, None]
|
||||
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
@@ -622,6 +631,7 @@ def moe_align_block_size(
|
||||
def invoke_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
assert bias is None
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
bias,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
bias.stride(0) if bias is not None else 0,
|
||||
bias.stride(1) if bias is not None else 0,
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
True,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
|
||||
block_shape,
|
||||
False,
|
||||
routed_scaling_factor,
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
False,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -1157,6 +1195,8 @@ def fused_experts(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -1174,6 +1214,8 @@ def fused_experts(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
):
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if inplace:
|
||||
@@ -1184,6 +1226,8 @@ def fused_experts(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
@@ -1199,6 +1243,8 @@ def fused_experts(
|
||||
a2_scale,
|
||||
block_shape,
|
||||
routed_scaling_factor,
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
)
|
||||
return hidden_states
|
||||
else:
|
||||
@@ -1208,6 +1254,8 @@ def fused_experts(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
activation,
|
||||
apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
@@ -1224,6 +1272,8 @@ def fused_experts(
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
|
||||
out.mul_(routed_scaling_factor)
|
||||
|
||||
|
||||
@torch.compile
|
||||
def swiglu_with_alpha_and_limit(x, alpha, limit):
|
||||
gate, up = x[..., ::2], x[..., 1::2]
|
||||
gate = gate.clamp(min=None, max=limit)
|
||||
up = up.clamp(min=-limit, max=limit)
|
||||
return gate * torch.sigmoid(gate * alpha) * (up + 1)
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
|
||||
else:
|
||||
assert (
|
||||
hidden_states.shape[1] == w1.shape[2] - padded_size
|
||||
), "Hidden size mismatch"
|
||||
), f"Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
|
||||
invoke_fused_moe_kernel(
|
||||
curr_hidden_states,
|
||||
w1,
|
||||
b1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
if activation == "silu":
|
||||
if _is_cuda:
|
||||
if activation_alpha is not None:
|
||||
assert swiglu_limit is not None
|
||||
intermediate_cache2 = swiglu_with_alpha_and_limit(
|
||||
intermediate_cache1.view(-1, N),
|
||||
activation_alpha,
|
||||
swiglu_limit,
|
||||
)
|
||||
elif _is_cuda:
|
||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
vllm_ops.silu_and_mul(
|
||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||
)
|
||||
elif activation == "gelu":
|
||||
assert (
|
||||
activation_alpha is None
|
||||
), "activation_alpha is not supported for gelu"
|
||||
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
|
||||
if _is_cuda:
|
||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||
else:
|
||||
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
b2,
|
||||
(
|
||||
intermediate_cache3
|
||||
if not no_combine and topk_ids.shape[1] != 1
|
||||
@@ -1567,6 +1642,8 @@ def fused_moe(
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
@@ -1584,6 +1661,8 @@ def fused_moe(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
activation_alpha: Optional[float] = None,
|
||||
swiglu_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1594,6 +1673,8 @@ def fused_moe(
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_output (TopKOutput): The top-k output of the experts.
|
||||
- b1 (Optional[torch.Tensor]): Optional bias for w1.
|
||||
- b2 (Optional[torch.Tensor]): Optional bias for w2.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
@@ -1615,6 +1696,10 @@ def fused_moe(
|
||||
a2.
|
||||
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
||||
quantization.
|
||||
- activation_alpha (Optional[float]): Optional alpha for the activation
|
||||
function.
|
||||
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
|
||||
function.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
@@ -1625,6 +1710,8 @@ def fused_moe(
|
||||
w1,
|
||||
w2,
|
||||
topk_output,
|
||||
b1=b1,
|
||||
b2=b2,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
@@ -1642,4 +1729,6 @@ def fused_moe(
|
||||
block_shape=block_shape,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
@@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||
self.use_triton_kernels, with_bias=with_bias
|
||||
self.use_triton_kernels
|
||||
)
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||
@@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module):
|
||||
# If we are in EP mode, we need to move the expert map to GPU.
|
||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||
|
||||
if self.expert_map_gpu is not None:
|
||||
if self.expert_map_gpu is not None and isinstance(
|
||||
topk_output, StandardTopKOutput
|
||||
):
|
||||
topk_output = topk_output._replace(
|
||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import triton.language as tl
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_hip,
|
||||
is_triton_kernels_available,
|
||||
log_info_on_rank0,
|
||||
next_power_of_2,
|
||||
round_up,
|
||||
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
||||
has_triton_kernels = is_triton_kernels_available()
|
||||
|
||||
|
||||
if is_flashinfer_available():
|
||||
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
|
||||
):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
use_flashinfer = global_server_args_dict.get(
|
||||
"enable_flashinfer_mxfp4_moe", False
|
||||
)
|
||||
return Mxfp4MoEMethod(
|
||||
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
|
||||
)
|
||||
return Mxfp4MoEMethod(prefix)
|
||||
else:
|
||||
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_triton_kernels: bool = True,
|
||||
with_bias: bool = True,
|
||||
use_flashinfer: bool = False,
|
||||
prefix: str,
|
||||
):
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
self.with_bias = with_bias
|
||||
self.use_flashinfer = use_flashinfer
|
||||
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||
self.with_bias = False
|
||||
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# print(f"hi {self=} create_weights {layer=}")
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
self.with_bias = with_bias
|
||||
mxfp4_block = 32
|
||||
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype,
|
||||
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w13_weight_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype,
|
||||
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
layer.num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype,
|
||||
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_weight_bias = torch.nn.Parameter(
|
||||
torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
|
||||
torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
||||
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
return
|
||||
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
if self.use_triton_kernels:
|
||||
|
||||
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
||||
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
|
||||
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
||||
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
||||
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
||||
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
||||
|
||||
num_warps = 8
|
||||
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
||||
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
||||
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
num_warps = 8
|
||||
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||
)
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||
)
|
||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||
)
|
||||
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
self.w13_precision_config = PrecisionConfig(
|
||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||
)
|
||||
self.w2_precision_config = PrecisionConfig(
|
||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||
)
|
||||
|
||||
# need to delete the original weights to save memory on single GPU
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
layer.w13_weight = None
|
||||
layer.w2_weight = None
|
||||
self.w13_weight_triton_tensor = w13_weight
|
||||
self.w2_weight_triton_tensor = w2_weight
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
else:
|
||||
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
|
||||
|
||||
w13_weight = upcast_from_mxfp(
|
||||
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
w2_weight = upcast_from_mxfp(
|
||||
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||
)
|
||||
del layer.w13_weight
|
||||
del layer.w2_weight
|
||||
del layer.w13_weight_scale
|
||||
del layer.w2_weight_scale
|
||||
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
self.num_experts,
|
||||
layer.num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
0, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
1, # routing_method_type, renormalize
|
||||
@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
return trtllm_gen_output
|
||||
|
||||
if self.use_triton_kernels:
|
||||
assert (
|
||||
layer.moe_ep_size == 1
|
||||
), "Expert parallel is not supported when using triton kernels"
|
||||
if self.with_bias:
|
||||
# TODO why we do not put weights on layer?
|
||||
assert layer.w13_weight is None
|
||||
assert layer.w2_weight is None
|
||||
return self.triton_kernel_moe_with_bias_forward(
|
||||
hidden_states=x,
|
||||
w1=self.w13_weight_triton_tensor,
|
||||
@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
topk_output=topk_output,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_output=topk_output,
|
||||
b1=layer.w13_weight_bias,
|
||||
b2=layer.w2_weight_bias,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
@@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
|
||||
def __init__(self, use_triton_kernels: bool = False):
|
||||
super().__init__()
|
||||
self.use_triton_kernels = use_triton_kernels
|
||||
self.with_bias = with_bias
|
||||
self.with_bias = False
|
||||
|
||||
self.triton_kernel_moe_forward = None
|
||||
self.triton_kernel_moe_with_bias_forward = None
|
||||
@@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
with_bias: bool = False,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
self.with_bias = with_bias
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||
if self.use_triton_kernels:
|
||||
@@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
b1=getattr(layer, "w13_weight_bias", None),
|
||||
b2=getattr(layer, "w2_weight_bias", None),
|
||||
topk_output=topk_output,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
no_combine=no_combine,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
activation_alpha=activation_alpha,
|
||||
swiglu_limit=swiglu_limit,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
|
||||
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_rank,
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_moe_tensor_parallel_rank,
|
||||
get_moe_tensor_parallel_world_size,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
self.activation = config.hidden_act
|
||||
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||
self.swiglu_limit = config.swiglu_limit
|
||||
if self.tp_size > config.num_local_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {config.num_local_experts}."
|
||||
)
|
||||
|
||||
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
||||
self.topk = None
|
||||
@@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
mxfp4_block = 32
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
moe_tp_rank = get_moe_tensor_parallel_rank()
|
||||
moe_tp_size = get_moe_tensor_parallel_world_size()
|
||||
moe_ep_rank = get_moe_expert_parallel_rank()
|
||||
moe_ep_size = get_moe_expert_parallel_world_size()
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
per_rank_intermediate_size_block = intermediate_size_block // tp_size
|
||||
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
|
||||
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
||||
|
||||
# Calculate common slicing bounds for current rank
|
||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
|
||||
|
||||
# Attention heads per rank
|
||||
heads_per_rank = self.config.num_attention_heads // tp_size
|
||||
head_start = tp_rank * heads_per_rank
|
||||
|
||||
num_experts = self.config.num_local_experts
|
||||
assert self.config.num_local_experts % moe_ep_size == 0
|
||||
moe_num_global_experts = self.config.num_local_experts
|
||||
moe_num_local_experts = self.config.num_local_experts // moe_ep_size
|
||||
moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
|
||||
moe_tp_rank_end = min(
|
||||
(moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
|
||||
)
|
||||
moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
|
||||
moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
|
||||
|
||||
for name, weight in weights:
|
||||
weight = weight.cuda()
|
||||
@@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module):
|
||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||
weight = weight.view(
|
||||
num_experts, 2 * intermediate_size, -1
|
||||
moe_num_global_experts, 2 * intermediate_size, -1
|
||||
).contiguous()
|
||||
|
||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
|
||||
narrow_weight = weight[
|
||||
moe_ep_rank_start:moe_ep_rank_end,
|
||||
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||
...,
|
||||
]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
@@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module):
|
||||
# same flatten here, but since 2 mx4 value are packed in 1
|
||||
# uint8, divide by 2
|
||||
weight = weight.view(
|
||||
num_experts, -1, intermediate_size // 2
|
||||
moe_num_global_experts, -1, intermediate_size // 2
|
||||
).contiguous()
|
||||
narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
|
||||
narrow_weight = weight[
|
||||
moe_ep_rank_start:moe_ep_rank_end,
|
||||
...,
|
||||
moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
|
||||
]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
@@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module):
|
||||
elif "gate_up_proj_scales" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
|
||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
|
||||
narrow_weight = weight[
|
||||
moe_ep_rank_start:moe_ep_rank_end,
|
||||
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||
...,
|
||||
]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
@@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module):
|
||||
# Handle MLP down projection weights
|
||||
new_name = name.replace("down_proj_scales", "w2_weight_scale")
|
||||
narrow_weight = weight[
|
||||
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
|
||||
moe_ep_rank_start:moe_ep_rank_end,
|
||||
...,
|
||||
moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
|
||||
]
|
||||
|
||||
param = params_dict[new_name]
|
||||
@@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module):
|
||||
# Handle MLP gate and up projection biases
|
||||
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
|
||||
|
||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
|
||||
narrow_weight = weight[
|
||||
moe_ep_rank_start:moe_ep_rank_end,
|
||||
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||
]
|
||||
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
@@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module):
|
||||
loaded_params.add(new_name)
|
||||
|
||||
elif "down_proj_bias" in name:
|
||||
if get_moe_tensor_parallel_rank() != 0:
|
||||
weight = torch.zeros_like(weight)
|
||||
narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
|
||||
if moe_tp_rank != 0:
|
||||
narrow_weight = torch.zeros_like(narrow_weight)
|
||||
|
||||
# Handle MLP down projection bias
|
||||
new_name = name.replace("down_proj_bias", "w2_weight_bias")
|
||||
param = params_dict[new_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(
|
||||
param, weight, weight_name=new_name, shard_id=None, expert_id=None
|
||||
param,
|
||||
narrow_weight,
|
||||
weight_name=new_name,
|
||||
shard_id=None,
|
||||
expert_id=None,
|
||||
)
|
||||
loaded_params.add(new_name)
|
||||
|
||||
@@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module):
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
|
||||
expert_params_mapping = (
|
||||
get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
|
||||
ckpt_gate_up_proj_name="gate_up_proj_blocks",
|
||||
ckpt_down_proj_name="down_proj_blocks",
|
||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||
ckpt_down_proj_bias_name="down_proj_bias",
|
||||
ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
|
||||
ckpt_down_proj_scale_name="down_proj_scales",
|
||||
)
|
||||
)
|
||||
else:
|
||||
expert_params_mapping = (
|
||||
get_moe_impl_class().make_expert_params_mapping_fused(
|
||||
ckpt_gate_up_proj_name="gate_up_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||
ckpt_down_proj_bias_name="down_proj_bias",
|
||||
)
|
||||
)
|
||||
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
|
||||
ckpt_gate_up_proj_name="gate_up_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||
ckpt_down_proj_bias_name="down_proj_bias",
|
||||
)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
params_checker = {k: False for k, v in params_dict.items()}
|
||||
|
||||
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
is_port_available,
|
||||
is_remote_url,
|
||||
is_triton_kernels_available,
|
||||
is_valid_ipv6_address,
|
||||
nullable_str,
|
||||
)
|
||||
@@ -492,10 +493,15 @@ class ServerArgs:
|
||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||
)
|
||||
else:
|
||||
self.enable_triton_kernel_moe = True
|
||||
logger.info(
|
||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||
)
|
||||
if self.enable_triton_kernel_moe:
|
||||
assert (
|
||||
self.ep_size == 1
|
||||
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
||||
self.enable_triton_kernel_moe = True
|
||||
logger.info(
|
||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||
)
|
||||
|
||||
self.disable_hybrid_swa_memory = True
|
||||
|
||||
|
||||
@@ -2961,3 +2961,8 @@ class ConcurrentCounter:
|
||||
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||
"""
|
||||
self.wait_for(lambda count: count == 0)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_triton_kernels_available() -> bool:
|
||||
return importlib.util.find_spec("triton_kernels") is not None
|
||||
|
||||
Reference in New Issue
Block a user