Expert Parallelism for GPT-OSS (#8944)
This commit is contained in:
@@ -76,6 +76,9 @@ class EPMoE(FusedMoE):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
|
with_bias: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_experts=num_experts,
|
num_experts=num_experts,
|
||||||
@@ -91,6 +94,9 @@ class EPMoE(FusedMoE):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
# apply_router_weight_on_input=apply_router_weight_on_input,
|
# apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
|
with_bias=with_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
|
||||||
|
|||||||
@@ -319,6 +319,7 @@ def fused_moe_kernel(
|
|||||||
# Pointers to matrices
|
# Pointers to matrices
|
||||||
a_ptr,
|
a_ptr,
|
||||||
b_ptr,
|
b_ptr,
|
||||||
|
bias_ptr,
|
||||||
c_ptr,
|
c_ptr,
|
||||||
a_scale_ptr,
|
a_scale_ptr,
|
||||||
b_scale_ptr,
|
b_scale_ptr,
|
||||||
@@ -340,6 +341,8 @@ def fused_moe_kernel(
|
|||||||
stride_be,
|
stride_be,
|
||||||
stride_bk,
|
stride_bk,
|
||||||
stride_bn,
|
stride_bn,
|
||||||
|
stride_bias_e,
|
||||||
|
stride_bias_n,
|
||||||
stride_cm,
|
stride_cm,
|
||||||
stride_cn,
|
stride_cn,
|
||||||
stride_asm,
|
stride_asm,
|
||||||
@@ -449,6 +452,10 @@ def fused_moe_kernel(
|
|||||||
+ off_experts * stride_be
|
+ off_experts * stride_be
|
||||||
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||||
)
|
)
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias = tl.load(
|
||||||
|
bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
|
||||||
|
)
|
||||||
if use_int8_w8a16:
|
if use_int8_w8a16:
|
||||||
b_scale_ptrs = (
|
b_scale_ptrs = (
|
||||||
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
|
||||||
@@ -526,18 +533,20 @@ def fused_moe_kernel(
|
|||||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
if use_int8_w8a16:
|
||||||
|
accumulator *= b_scale
|
||||||
|
elif use_fp8_w8a8 or use_int8_w8a8:
|
||||||
|
if group_k == 0 or group_n == 0:
|
||||||
|
accumulator *= a_scale * b_scale
|
||||||
|
|
||||||
|
if bias_ptr is not None:
|
||||||
|
accumulator += bias
|
||||||
|
|
||||||
if MUL_ROUTED_WEIGHT:
|
if MUL_ROUTED_WEIGHT:
|
||||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
accumulator *= moe_weight[:, None]
|
||||||
if use_int8_w8a16:
|
|
||||||
accumulator = (accumulator * b_scale).to(compute_type)
|
accumulator = accumulator.to(compute_type)
|
||||||
elif use_fp8_w8a8 or use_int8_w8a8:
|
|
||||||
if group_k > 0 and group_n > 0:
|
|
||||||
accumulator = accumulator.to(compute_type)
|
|
||||||
else:
|
|
||||||
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
|
|
||||||
else:
|
|
||||||
accumulator = accumulator.to(compute_type)
|
|
||||||
# -----------------------------------------------------------
|
# -----------------------------------------------------------
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
@@ -622,6 +631,7 @@ def moe_align_block_size(
|
|||||||
def invoke_fused_moe_kernel(
|
def invoke_fused_moe_kernel(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor],
|
||||||
C: torch.Tensor,
|
C: torch.Tensor,
|
||||||
A_scale: Optional[torch.Tensor],
|
A_scale: Optional[torch.Tensor],
|
||||||
B_scale: Optional[torch.Tensor],
|
B_scale: Optional[torch.Tensor],
|
||||||
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
|
|||||||
):
|
):
|
||||||
assert B_scale is not None and B_scale.ndim == 3
|
assert B_scale is not None and B_scale.ndim == 3
|
||||||
assert B_zp is None or B_zp.ndim == 3
|
assert B_zp is None or B_zp.ndim == 3
|
||||||
|
assert bias is None
|
||||||
fused_moe_kernel_gptq_awq[grid](
|
fused_moe_kernel_gptq_awq[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
|
|||||||
fused_moe_kernel[grid](
|
fused_moe_kernel[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
|
bias,
|
||||||
C,
|
C,
|
||||||
A_scale,
|
A_scale,
|
||||||
B_scale,
|
B_scale,
|
||||||
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
|
|||||||
B.stride(0),
|
B.stride(0),
|
||||||
B.stride(2),
|
B.stride(2),
|
||||||
B.stride(1),
|
B.stride(1),
|
||||||
|
bias.stride(0) if bias is not None else 0,
|
||||||
|
bias.stride(1) if bias is not None else 0,
|
||||||
C.stride(1),
|
C.stride(1),
|
||||||
C.stride(2),
|
C.stride(2),
|
||||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||||
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
|
|||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
fused_experts_impl(
|
fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
True,
|
True,
|
||||||
activation,
|
activation,
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
|
|||||||
block_shape,
|
block_shape,
|
||||||
False,
|
False,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
|
activation_alpha,
|
||||||
|
swiglu_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
|
|||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
False,
|
False,
|
||||||
activation,
|
activation,
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
|
|||||||
block_shape,
|
block_shape,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
@@ -1157,6 +1195,8 @@ def fused_experts(
|
|||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
@@ -1174,6 +1214,8 @@ def fused_experts(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
):
|
):
|
||||||
topk_weights, topk_ids, _ = topk_output
|
topk_weights, topk_ids, _ = topk_output
|
||||||
if inplace:
|
if inplace:
|
||||||
@@ -1184,6 +1226,8 @@ def fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
activation,
|
activation,
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
@@ -1199,6 +1243,8 @@ def fused_experts(
|
|||||||
a2_scale,
|
a2_scale,
|
||||||
block_shape,
|
block_shape,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
|
activation_alpha,
|
||||||
|
swiglu_limit,
|
||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -1208,6 +1254,8 @@ def fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
b1,
|
||||||
|
b2,
|
||||||
activation,
|
activation,
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
use_fp8_w8a8,
|
use_fp8_w8a8,
|
||||||
@@ -1224,6 +1272,8 @@ def fused_experts(
|
|||||||
block_shape,
|
block_shape,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
|
|||||||
out.mul_(routed_scaling_factor)
|
out.mul_(routed_scaling_factor)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
|
def swiglu_with_alpha_and_limit(x, alpha, limit):
|
||||||
|
gate, up = x[..., ::2], x[..., 1::2]
|
||||||
|
gate = gate.clamp(min=None, max=limit)
|
||||||
|
up = up.clamp(min=-limit, max=limit)
|
||||||
|
return gate * torch.sigmoid(gate * alpha) * (up + 1)
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_impl(
|
def fused_experts_impl(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
):
|
):
|
||||||
padded_size = padding_size
|
padded_size = padding_size
|
||||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||||
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
|
|||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
hidden_states.shape[1] == w1.shape[2] - padded_size
|
hidden_states.shape[1] == w1.shape[2] - padded_size
|
||||||
), "Hidden size mismatch"
|
), f"Hidden size mismatch"
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||||
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
|
|||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
curr_hidden_states,
|
curr_hidden_states,
|
||||||
w1,
|
w1,
|
||||||
|
b1,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
a1_scale,
|
a1_scale,
|
||||||
w1_scale,
|
w1_scale,
|
||||||
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
if activation == "silu":
|
if activation == "silu":
|
||||||
if _is_cuda:
|
if activation_alpha is not None:
|
||||||
|
assert swiglu_limit is not None
|
||||||
|
intermediate_cache2 = swiglu_with_alpha_and_limit(
|
||||||
|
intermediate_cache1.view(-1, N),
|
||||||
|
activation_alpha,
|
||||||
|
swiglu_limit,
|
||||||
|
)
|
||||||
|
elif _is_cuda:
|
||||||
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
else:
|
else:
|
||||||
vllm_ops.silu_and_mul(
|
vllm_ops.silu_and_mul(
|
||||||
intermediate_cache2, intermediate_cache1.view(-1, N)
|
intermediate_cache2, intermediate_cache1.view(-1, N)
|
||||||
)
|
)
|
||||||
elif activation == "gelu":
|
elif activation == "gelu":
|
||||||
|
assert (
|
||||||
|
activation_alpha is None
|
||||||
|
), "activation_alpha is not supported for gelu"
|
||||||
|
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
|
||||||
else:
|
else:
|
||||||
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
|
|||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
intermediate_cache2,
|
intermediate_cache2,
|
||||||
w2,
|
w2,
|
||||||
|
b2,
|
||||||
(
|
(
|
||||||
intermediate_cache3
|
intermediate_cache3
|
||||||
if not no_combine and topk_ids.shape[1] != 1
|
if not no_combine and topk_ids.shape[1] != 1
|
||||||
@@ -1567,6 +1642,8 @@ def fused_moe(
|
|||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
|
b1: Optional[torch.Tensor] = None,
|
||||||
|
b2: Optional[torch.Tensor] = None,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
@@ -1584,6 +1661,8 @@ def fused_moe(
|
|||||||
block_shape: Optional[List[int]] = None,
|
block_shape: Optional[List[int]] = None,
|
||||||
no_combine: bool = False,
|
no_combine: bool = False,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
|
activation_alpha: Optional[float] = None,
|
||||||
|
swiglu_limit: Optional[float] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||||
@@ -1594,6 +1673,8 @@ def fused_moe(
|
|||||||
- w1 (torch.Tensor): The first set of expert weights.
|
- w1 (torch.Tensor): The first set of expert weights.
|
||||||
- w2 (torch.Tensor): The second set of expert weights.
|
- w2 (torch.Tensor): The second set of expert weights.
|
||||||
- topk_output (TopKOutput): The top-k output of the experts.
|
- topk_output (TopKOutput): The top-k output of the experts.
|
||||||
|
- b1 (Optional[torch.Tensor]): Optional bias for w1.
|
||||||
|
- b2 (Optional[torch.Tensor]): Optional bias for w2.
|
||||||
- inplace (bool): If True, perform the operation in-place.
|
- inplace (bool): If True, perform the operation in-place.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||||
@@ -1615,6 +1696,10 @@ def fused_moe(
|
|||||||
a2.
|
a2.
|
||||||
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
||||||
quantization.
|
quantization.
|
||||||
|
- activation_alpha (Optional[float]): Optional alpha for the activation
|
||||||
|
function.
|
||||||
|
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
|
||||||
|
function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
@@ -1625,6 +1710,8 @@ def fused_moe(
|
|||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_output,
|
topk_output,
|
||||||
|
b1=b1,
|
||||||
|
b2=b2,
|
||||||
inplace=inplace,
|
inplace=inplace,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
@@ -1642,4 +1729,6 @@ def fused_moe(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
|
||||||
self.use_triton_kernels, with_bias=with_bias
|
self.use_triton_kernels
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.quant_method = quant_config.get_quant_method(self, prefix)
|
self.quant_method = quant_config.get_quant_method(self, prefix)
|
||||||
@@ -809,7 +809,9 @@ class FusedMoE(torch.nn.Module):
|
|||||||
# If we are in EP mode, we need to move the expert map to GPU.
|
# If we are in EP mode, we need to move the expert map to GPU.
|
||||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||||
|
|
||||||
if self.expert_map_gpu is not None:
|
if self.expert_map_gpu is not None and isinstance(
|
||||||
|
topk_output, StandardTopKOutput
|
||||||
|
):
|
||||||
topk_output = topk_output._replace(
|
topk_output = topk_output._replace(
|
||||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import logging
|
|||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import triton.language as tl
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
@@ -24,6 +25,7 @@ from sglang.srt.utils import (
|
|||||||
is_cuda,
|
is_cuda,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
|
is_triton_kernels_available,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
next_power_of_2,
|
next_power_of_2,
|
||||||
round_up,
|
round_up,
|
||||||
@@ -31,7 +33,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
|
has_triton_kernels = is_triton_kernels_available()
|
||||||
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
@@ -188,12 +190,7 @@ class Mxfp4Config(QuantizationConfig):
|
|||||||
):
|
):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
elif isinstance(layer, FusedMoE):
|
elif isinstance(layer, FusedMoE):
|
||||||
use_flashinfer = global_server_args_dict.get(
|
return Mxfp4MoEMethod(prefix)
|
||||||
"enable_flashinfer_mxfp4_moe", False
|
|
||||||
)
|
|
||||||
return Mxfp4MoEMethod(
|
|
||||||
use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
||||||
return None
|
return None
|
||||||
@@ -206,15 +203,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_triton_kernels: bool = True,
|
prefix: str,
|
||||||
with_bias: bool = True,
|
|
||||||
use_flashinfer: bool = False,
|
|
||||||
):
|
):
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
self.use_triton_kernels = use_triton_kernels
|
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
|
||||||
self.with_bias = with_bias
|
self.with_bias = False
|
||||||
self.use_flashinfer = use_flashinfer
|
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = None
|
self.triton_kernel_moe_forward = None
|
||||||
self.triton_kernel_moe_with_bias_forward = None
|
self.triton_kernel_moe_with_bias_forward = None
|
||||||
@@ -236,12 +234,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
|
with_bias: bool = False,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
# print(f"hi {self=} create_weights {layer=}")
|
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
weight_dtype = torch.uint8
|
weight_dtype = torch.uint8
|
||||||
scale_dtype = torch.uint8
|
scale_dtype = torch.uint8
|
||||||
|
self.with_bias = with_bias
|
||||||
mxfp4_block = 32
|
mxfp4_block = 32
|
||||||
|
|
||||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||||
@@ -264,7 +263,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
w13_weight = torch.nn.Parameter(
|
w13_weight = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
layer.num_local_experts,
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
hidden_size // 2,
|
hidden_size // 2,
|
||||||
dtype=weight_dtype,
|
dtype=weight_dtype,
|
||||||
@@ -276,7 +275,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
w13_weight_scale = torch.nn.Parameter(
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
layer.num_local_experts,
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
hidden_size // mxfp4_block,
|
hidden_size // mxfp4_block,
|
||||||
dtype=scale_dtype,
|
dtype=scale_dtype,
|
||||||
@@ -288,7 +287,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
w13_weight_bias = torch.nn.Parameter(
|
w13_weight_bias = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
layer.num_local_experts,
|
||||||
2 * intermediate_size_per_partition_after_pad,
|
2 * intermediate_size_per_partition_after_pad,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
),
|
),
|
||||||
@@ -300,7 +299,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
# down_proj (row parallel)
|
# down_proj (row parallel)
|
||||||
w2_weight = torch.nn.Parameter(
|
w2_weight = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
layer.num_local_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size_per_partition_after_pad // 2,
|
intermediate_size_per_partition_after_pad // 2,
|
||||||
dtype=weight_dtype,
|
dtype=weight_dtype,
|
||||||
@@ -312,7 +311,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
w2_weight_scale = torch.nn.Parameter(
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
num_experts,
|
layer.num_local_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||||
dtype=scale_dtype,
|
dtype=scale_dtype,
|
||||||
@@ -323,7 +322,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
w2_weight_bias = torch.nn.Parameter(
|
w2_weight_bias = torch.nn.Parameter(
|
||||||
torch.zeros(num_experts, hidden_size, dtype=torch.bfloat16),
|
torch.zeros(layer.num_local_experts, hidden_size, dtype=torch.bfloat16),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
layer.register_parameter("w2_weight_bias", w2_weight_bias)
|
||||||
@@ -484,38 +483,51 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
if self.use_triton_kernels:
|
||||||
|
|
||||||
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||||
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
|
||||||
|
|
||||||
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
w13_weight_bias = layer.w13_weight_bias.to(torch.float32)
|
||||||
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
w2_weight_bias = layer.w2_weight_bias.to(torch.float32)
|
||||||
|
|
||||||
num_warps = 8
|
layer.w13_weight_bias = Parameter(w13_weight_bias, requires_grad=False)
|
||||||
|
layer.w2_weight_bias = Parameter(w2_weight_bias, requires_grad=False)
|
||||||
|
|
||||||
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
num_warps = 8
|
||||||
layer.w13_weight, layer.w13_weight_scale, num_warps
|
|
||||||
)
|
|
||||||
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
|
||||||
layer.w2_weight, layer.w2_weight_scale, num_warps
|
|
||||||
)
|
|
||||||
|
|
||||||
self.w13_precision_config = PrecisionConfig(
|
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
|
||||||
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
layer.w13_weight, layer.w13_weight_scale, num_warps
|
||||||
)
|
)
|
||||||
self.w2_precision_config = PrecisionConfig(
|
w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
|
||||||
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
layer.w2_weight, layer.w2_weight_scale, num_warps
|
||||||
)
|
)
|
||||||
|
|
||||||
self.w13_weight_triton_tensor = w13_weight
|
self.w13_precision_config = PrecisionConfig(
|
||||||
self.w2_weight_triton_tensor = w2_weight
|
weight_scale=w13_scale, flex_ctx=FlexCtx(rhs_data=w13_flex)
|
||||||
|
)
|
||||||
|
self.w2_precision_config = PrecisionConfig(
|
||||||
|
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
|
||||||
|
)
|
||||||
|
|
||||||
# need to delete the original weights to save memory on single GPU
|
self.w13_weight_triton_tensor = w13_weight
|
||||||
del layer.w13_weight
|
self.w2_weight_triton_tensor = w2_weight
|
||||||
del layer.w2_weight
|
del layer.w13_weight
|
||||||
layer.w13_weight = None
|
del layer.w2_weight
|
||||||
layer.w2_weight = None
|
else:
|
||||||
|
from triton_kernels.numerics_details.mxfp import upcast_from_mxfp
|
||||||
|
|
||||||
|
w13_weight = upcast_from_mxfp(
|
||||||
|
layer.w13_weight, layer.w13_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||||
|
)
|
||||||
|
w2_weight = upcast_from_mxfp(
|
||||||
|
layer.w2_weight, layer.w2_weight_scale, dtype=torch.bfloat16, axis=-1
|
||||||
|
)
|
||||||
|
del layer.w13_weight
|
||||||
|
del layer.w2_weight
|
||||||
|
del layer.w13_weight_scale
|
||||||
|
del layer.w2_weight_scale
|
||||||
|
layer.w13_weight = Parameter(w13_weight.data, requires_grad=False)
|
||||||
|
layer.w2_weight = Parameter(w2_weight.data, requires_grad=False)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||||
@@ -580,13 +592,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
None, # output1_scale_scalar
|
None, # output1_scale_scalar
|
||||||
None, # output1_scale_gate_scalar
|
None, # output1_scale_gate_scalar
|
||||||
None, # output2_scale_scalar
|
None, # output2_scale_scalar
|
||||||
self.num_experts,
|
layer.num_experts,
|
||||||
top_k,
|
top_k,
|
||||||
None, # n_group
|
None, # n_group
|
||||||
None, # topk_group
|
None, # topk_group
|
||||||
self.intermediate_size, # padded to multiple of 256
|
self.intermediate_size, # padded to multiple of 256
|
||||||
0, # local_expert_offset
|
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||||
self.num_experts, # local num experts
|
layer.num_local_experts, # local num experts
|
||||||
None,
|
None,
|
||||||
self._get_tile_tokens_dim(x, top_k),
|
self._get_tile_tokens_dim(x, top_k),
|
||||||
1, # routing_method_type, renormalize
|
1, # routing_method_type, renormalize
|
||||||
@@ -595,10 +607,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
return trtllm_gen_output
|
return trtllm_gen_output
|
||||||
|
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
|
assert (
|
||||||
|
layer.moe_ep_size == 1
|
||||||
|
), "Expert parallel is not supported when using triton kernels"
|
||||||
if self.with_bias:
|
if self.with_bias:
|
||||||
# TODO why we do not put weights on layer?
|
|
||||||
assert layer.w13_weight is None
|
|
||||||
assert layer.w2_weight is None
|
|
||||||
return self.triton_kernel_moe_with_bias_forward(
|
return self.triton_kernel_moe_with_bias_forward(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=self.w13_weight_triton_tensor,
|
w1=self.w13_weight_triton_tensor,
|
||||||
@@ -620,4 +632,20 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_output=topk_output,
|
||||||
|
b1=layer.w13_weight_bias,
|
||||||
|
b2=layer.w2_weight_bias,
|
||||||
|
inplace=inplace,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
no_combine=no_combine,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
|
)
|
||||||
|
|||||||
@@ -126,10 +126,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
|||||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||||
"""MoE method without quantization."""
|
"""MoE method without quantization."""
|
||||||
|
|
||||||
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
|
def __init__(self, use_triton_kernels: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_triton_kernels = use_triton_kernels
|
self.use_triton_kernels = use_triton_kernels
|
||||||
self.with_bias = with_bias
|
self.with_bias = False
|
||||||
|
|
||||||
self.triton_kernel_moe_forward = None
|
self.triton_kernel_moe_forward = None
|
||||||
self.triton_kernel_moe_with_bias_forward = None
|
self.triton_kernel_moe_with_bias_forward = None
|
||||||
@@ -151,8 +151,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
|
with_bias: bool = False,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
|
self.with_bias = with_bias
|
||||||
|
|
||||||
# Fused gate_up_proj (column parallel)
|
# Fused gate_up_proj (column parallel)
|
||||||
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
@@ -319,12 +322,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
|
b1=getattr(layer, "w13_weight_bias", None),
|
||||||
|
b2=getattr(layer, "w2_weight_bias", None),
|
||||||
topk_output=topk_output,
|
topk_output=topk_output,
|
||||||
inplace=inplace and not no_combine,
|
inplace=inplace and not no_combine,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
no_combine=no_combine,
|
no_combine=no_combine,
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
activation_alpha=activation_alpha,
|
||||||
|
swiglu_limit=swiglu_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cpu(
|
def forward_cpu(
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from sglang.srt.distributed import (
|
|||||||
get_moe_expert_parallel_rank,
|
get_moe_expert_parallel_rank,
|
||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
get_moe_tensor_parallel_rank,
|
get_moe_tensor_parallel_rank,
|
||||||
|
get_moe_tensor_parallel_world_size,
|
||||||
get_pp_group,
|
get_pp_group,
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -96,11 +97,6 @@ class GptOssSparseMoeBlock(nn.Module):
|
|||||||
self.activation = config.hidden_act
|
self.activation = config.hidden_act
|
||||||
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
|
||||||
self.swiglu_limit = config.swiglu_limit
|
self.swiglu_limit = config.swiglu_limit
|
||||||
if self.tp_size > config.num_local_experts:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
|
||||||
f"the number of experts {config.num_local_experts}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
|
||||||
self.topk = None
|
self.topk = None
|
||||||
@@ -708,22 +704,26 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
mxfp4_block = 32
|
mxfp4_block = 32
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
moe_tp_rank = get_moe_tensor_parallel_rank()
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
moe_tp_size = get_moe_tensor_parallel_world_size()
|
||||||
|
moe_ep_rank = get_moe_expert_parallel_rank()
|
||||||
|
moe_ep_size = get_moe_expert_parallel_world_size()
|
||||||
|
|
||||||
intermediate_size = self.config.intermediate_size
|
intermediate_size = self.config.intermediate_size
|
||||||
intermediate_size_block = intermediate_size // mxfp4_block
|
intermediate_size_block = intermediate_size // mxfp4_block
|
||||||
per_rank_intermediate_size_block = intermediate_size_block // tp_size
|
per_rank_intermediate_size_block = intermediate_size_block // moe_tp_size
|
||||||
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block
|
||||||
|
|
||||||
# Calculate common slicing bounds for current rank
|
# Calculate common slicing bounds for current rank
|
||||||
tp_rank_start = tp_rank * per_rank_intermediate_size
|
assert self.config.num_local_experts % moe_ep_size == 0
|
||||||
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
|
moe_num_global_experts = self.config.num_local_experts
|
||||||
|
moe_num_local_experts = self.config.num_local_experts // moe_ep_size
|
||||||
# Attention heads per rank
|
moe_tp_rank_start = moe_tp_rank * per_rank_intermediate_size
|
||||||
heads_per_rank = self.config.num_attention_heads // tp_size
|
moe_tp_rank_end = min(
|
||||||
head_start = tp_rank * heads_per_rank
|
(moe_tp_rank + 1) * per_rank_intermediate_size, intermediate_size
|
||||||
|
)
|
||||||
num_experts = self.config.num_local_experts
|
moe_ep_rank_start = moe_ep_rank * moe_num_local_experts
|
||||||
|
moe_ep_rank_end = (moe_ep_rank + 1) * moe_num_local_experts
|
||||||
|
|
||||||
for name, weight in weights:
|
for name, weight in weights:
|
||||||
weight = weight.cuda()
|
weight = weight.cuda()
|
||||||
@@ -735,10 +735,14 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
# flat weight from (E, 2 * N, block_size, entry_per_block)
|
||||||
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
# to (E, 2 * N, -1), shouldn't trigger copy for contiguous
|
||||||
weight = weight.view(
|
weight = weight.view(
|
||||||
num_experts, 2 * intermediate_size, -1
|
moe_num_global_experts, 2 * intermediate_size, -1
|
||||||
).contiguous()
|
).contiguous()
|
||||||
|
|
||||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
|
narrow_weight = weight[
|
||||||
|
moe_ep_rank_start:moe_ep_rank_end,
|
||||||
|
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||||
|
...,
|
||||||
|
]
|
||||||
|
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
@@ -757,9 +761,13 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
# same flatten here, but since 2 mx4 value are packed in 1
|
# same flatten here, but since 2 mx4 value are packed in 1
|
||||||
# uint8, divide by 2
|
# uint8, divide by 2
|
||||||
weight = weight.view(
|
weight = weight.view(
|
||||||
num_experts, -1, intermediate_size // 2
|
moe_num_global_experts, -1, intermediate_size // 2
|
||||||
).contiguous()
|
).contiguous()
|
||||||
narrow_weight = weight[..., tp_rank_start // 2 : tp_rank_end // 2]
|
narrow_weight = weight[
|
||||||
|
moe_ep_rank_start:moe_ep_rank_end,
|
||||||
|
...,
|
||||||
|
moe_tp_rank_start // 2 : moe_tp_rank_end // 2,
|
||||||
|
]
|
||||||
|
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
@@ -775,7 +783,11 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
elif "gate_up_proj_scales" in name:
|
elif "gate_up_proj_scales" in name:
|
||||||
# Handle MLP gate and up projection weights scale
|
# Handle MLP gate and up projection weights scale
|
||||||
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
|
new_name = name.replace("gate_up_proj_scales", "w13_weight_scale")
|
||||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end, ...]
|
narrow_weight = weight[
|
||||||
|
moe_ep_rank_start:moe_ep_rank_end,
|
||||||
|
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||||
|
...,
|
||||||
|
]
|
||||||
|
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
@@ -792,7 +804,9 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
# Handle MLP down projection weights
|
# Handle MLP down projection weights
|
||||||
new_name = name.replace("down_proj_scales", "w2_weight_scale")
|
new_name = name.replace("down_proj_scales", "w2_weight_scale")
|
||||||
narrow_weight = weight[
|
narrow_weight = weight[
|
||||||
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
|
moe_ep_rank_start:moe_ep_rank_end,
|
||||||
|
...,
|
||||||
|
moe_tp_rank_start // mxfp4_block : moe_tp_rank_end // mxfp4_block,
|
||||||
]
|
]
|
||||||
|
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
@@ -809,7 +823,10 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
# Handle MLP gate and up projection biases
|
# Handle MLP gate and up projection biases
|
||||||
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
|
new_name = name.replace("gate_up_proj_bias", "w13_weight_bias")
|
||||||
|
|
||||||
narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
|
narrow_weight = weight[
|
||||||
|
moe_ep_rank_start:moe_ep_rank_end,
|
||||||
|
2 * moe_tp_rank_start : 2 * moe_tp_rank_end,
|
||||||
|
]
|
||||||
|
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
@@ -823,15 +840,20 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
loaded_params.add(new_name)
|
loaded_params.add(new_name)
|
||||||
|
|
||||||
elif "down_proj_bias" in name:
|
elif "down_proj_bias" in name:
|
||||||
if get_moe_tensor_parallel_rank() != 0:
|
narrow_weight = weight[moe_ep_rank_start:moe_ep_rank_end, ...]
|
||||||
weight = torch.zeros_like(weight)
|
if moe_tp_rank != 0:
|
||||||
|
narrow_weight = torch.zeros_like(narrow_weight)
|
||||||
|
|
||||||
# Handle MLP down projection bias
|
# Handle MLP down projection bias
|
||||||
new_name = name.replace("down_proj_bias", "w2_weight_bias")
|
new_name = name.replace("down_proj_bias", "w2_weight_bias")
|
||||||
param = params_dict[new_name]
|
param = params_dict[new_name]
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
weight_loader(
|
weight_loader(
|
||||||
param, weight, weight_name=new_name, shard_id=None, expert_id=None
|
param,
|
||||||
|
narrow_weight,
|
||||||
|
weight_name=new_name,
|
||||||
|
shard_id=None,
|
||||||
|
expert_id=None,
|
||||||
)
|
)
|
||||||
loaded_params.add(new_name)
|
loaded_params.add(new_name)
|
||||||
|
|
||||||
@@ -910,27 +932,12 @@ class GptOssForCausalLM(nn.Module):
|
|||||||
("qkv_proj", "k_proj", "k"),
|
("qkv_proj", "k_proj", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("qkv_proj", "v_proj", "v"),
|
||||||
]
|
]
|
||||||
|
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
|
||||||
if self.quant_config is not None and (self.quant_config.get_name() == "mxfp4"):
|
ckpt_gate_up_proj_name="gate_up_proj",
|
||||||
expert_params_mapping = (
|
ckpt_down_proj_name="down_proj",
|
||||||
get_moe_impl_class().make_expert_params_mapping_fused_mxfp4(
|
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
||||||
ckpt_gate_up_proj_name="gate_up_proj_blocks",
|
ckpt_down_proj_bias_name="down_proj_bias",
|
||||||
ckpt_down_proj_name="down_proj_blocks",
|
)
|
||||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
|
||||||
ckpt_down_proj_bias_name="down_proj_bias",
|
|
||||||
ckpt_gate_up_proj_scale_name="gate_up_proj_scales",
|
|
||||||
ckpt_down_proj_scale_name="down_proj_scales",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
expert_params_mapping = (
|
|
||||||
get_moe_impl_class().make_expert_params_mapping_fused(
|
|
||||||
ckpt_gate_up_proj_name="gate_up_proj",
|
|
||||||
ckpt_down_proj_name="down_proj",
|
|
||||||
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
|
|
||||||
ckpt_down_proj_bias_name="down_proj_bias",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
params_checker = {k: False for k, v in params_dict.items()}
|
params_checker = {k: False for k, v in params_dict.items()}
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
is_port_available,
|
is_port_available,
|
||||||
is_remote_url,
|
is_remote_url,
|
||||||
|
is_triton_kernels_available,
|
||||||
is_valid_ipv6_address,
|
is_valid_ipv6_address,
|
||||||
nullable_str,
|
nullable_str,
|
||||||
)
|
)
|
||||||
@@ -492,10 +493,15 @@ class ServerArgs:
|
|||||||
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.enable_triton_kernel_moe = True
|
if self.enable_triton_kernel_moe:
|
||||||
logger.info(
|
assert (
|
||||||
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
self.ep_size == 1
|
||||||
)
|
), "Triton kernel MoE is only supported when ep_size == 1"
|
||||||
|
if not self.enable_triton_kernel_moe and self.ep_size == 1:
|
||||||
|
self.enable_triton_kernel_moe = True
|
||||||
|
logger.info(
|
||||||
|
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
|
||||||
|
)
|
||||||
|
|
||||||
self.disable_hybrid_swa_memory = True
|
self.disable_hybrid_swa_memory = True
|
||||||
|
|
||||||
|
|||||||
@@ -2961,3 +2961,8 @@ class ConcurrentCounter:
|
|||||||
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
||||||
"""
|
"""
|
||||||
self.wait_for(lambda count: count == 0)
|
self.wait_for(lambda count: count == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def is_triton_kernels_available() -> bool:
|
||||||
|
return importlib.util.find_spec("triton_kernels") is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user