@@ -11,7 +11,7 @@ _is_hip = is_hip()


 @triton.jit
-def fused_moe_router_kernel(
+def fused_moe_router_cudacore_kernel(
     input_ptr,  # input (bs, hidden_dim)
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
     # assert not moe_renormalize, "moe weight renormalization not implemented"


-def fused_moe_router_impl(
+def fused_moe_router_cudacore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
         ),
     }

-    fused_moe_router_kernel[(bs,)](
+    fused_moe_router_cudacore_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
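
`kernel[(bs,)](...)` is Triton's launch-grid syntax: the tuple subscript sets the grid, here one program instance per token row. A toy sketch of that convention (illustrative only, not code from this PR):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def row_max_kernel(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)  # each program instance handles one row
    offs = tl.arange(0, BLOCK)
    vals = tl.load(x_ptr + row * n_cols + offs, mask=offs < n_cols, other=-float("inf"))
    tl.store(out_ptr + row, tl.max(vals, axis=0))


x = torch.randn(8, 32, device="cuda")
out = torch.empty(8, device="cuda")
row_max_kernel[(x.shape[0],)](x, out, x.shape[1], BLOCK=32)  # grid = (bs,)
assert torch.allclose(out, x.max(dim=1).values)
```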
@@ -157,7 +157,7 @@ def fused_moe_router_impl(


 @triton.jit
-def fused_moe_router_large_bs_kernel(
+def fused_moe_router_tensorcore_kernel(
     a_ptr,  # input (bs, hidden_dim)
     b_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
     topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     K: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     stride_am: tl.constexpr,
     stride_bn: tl.constexpr,
+    dp_attn_workaround_flag: tl.constexpr,
 ):

     # 1. get block id
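
Note that both new flags (`is_correction_bias`, `dp_attn_workaround_flag`) are `tl.constexpr`: Triton specializes the compiled kernel per flag value, so the untaken branch is compiled away entirely; the cost is one extra compilation per distinct value. A toy illustration (not from this PR):

```python
import triton
import triton.language as tl


@triton.jit
def add_bias_kernel(x_ptr, out_ptr, n, APPLY_BIAS: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n)
    if APPLY_BIAS:  # resolved at compile time: no per-element runtime check
        x = x + 1.0
    tl.store(out_ptr + offs, x, mask=offs < n)
```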
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
     exped = tl.exp(2 * logits_scaled)
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(
+            correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
+            mask=expert_mask.T,
+            other=0.0,
+        )
+        logits_softcapped = logits_softcapped + bias
+
+    if dp_attn_workaround_flag:
+        logits_softcapped = tl.where(
+            logits_softcapped != logits_softcapped, -1e9, logits_softcapped
+        )
+
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
     cond_top1 = arange_block_size_n < num_experts
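
The two context lines compute tanh softcapping, `cap * tanh(logits / cap)`, via the identity `tanh(y) = (e^{2y} - 1) / (e^{2y} + 1)` applied to the pre-scaled logits. The new block then adds the optional per-expert bias and scrubs NaNs: NaN is the only value for which `x != x` holds, and padded dp-attention rows can presumably produce NaN logits, so they are forced to a score that can never win top-k. A PyTorch restatement of the new logic (a sketch; the helper name is ours):

```python
import torch


def softcap_bias_nan_guard(logits, cap, bias=None):
    exped = torch.exp(2 * logits / cap)
    capped = (exped - 1) / (exped + 1) * cap  # == cap * tanh(logits / cap)
    if bias is not None:
        capped = capped + bias  # correction bias is added after softcapping
    # NaN != NaN, so this swaps NaNs for a logit that can never win top-k
    return torch.where(capped != capped, torch.full_like(capped, -1e9), capped)


x = torch.tensor([0.5, float("nan")])
print(softcap_bias_nan_guard(x, cap=30.0))  # the NaN slot becomes -1e9
```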
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
     )


-def fused_moe_router_large_bs_impl(
+def fused_moe_router_tensorcore(
     x: torch.Tensor,
     router_weight: torch.Tensor,
     topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
     BLOCK_SIZE_M: int,
     BLOCK_SIZE_N: int,
     BLOCK_SIZE_K: int,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(

     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None

     grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)

-    fused_moe_router_large_bs_kernel[grid](
+    # TODO(ch-wan): temporary workaround for dp attention. We should support masked
+    # router to skip padded tokens.
+    from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+
+    dp_attn_workaround_flag = is_dp_attention_enabled()
+
+    fused_moe_router_tensorcore_kernel[grid](
         a_ptr=x,
         b_ptr=router_weight,
         topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
         moe_softcapping=moe_softcapping,
         moe_renormalize=False,
         K=hidden_dim,
+        correction_bias_ptr=correction_bias,
+        is_correction_bias=is_correction_bias,
         BLOCK_SIZE_M=BLOCK_SIZE_M,
         BLOCK_SIZE_N=BLOCK_SIZE_N,
         BLOCK_SIZE_K=BLOCK_SIZE_K,
         stride_am=hidden_dim,
         stride_bn=hidden_dim,
+        dp_attn_workaround_flag=dp_attn_workaround_flag,
     )

     return topk_weights, topk_ids
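
A hypothetical call, to make the shapes and the 1-D grid concrete: 1024 tokens, hidden size 2048, 16 experts, top-2 routing, with `moe_softcapping` assumed to sit in the elided part of the signature, as the kernel call above suggests. The grid flattens `cdiv(bs, BLOCK_SIZE_M)` token tiles times `cdiv(num_experts, BLOCK_SIZE_N)` expert tiles:

```python
import torch

x = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
router_weight = torch.randn(16, 2048, device="cuda", dtype=torch.bfloat16)

# grid = cdiv(1024, 32) * cdiv(16, 16) = 32 program instances
topk_weights, topk_ids = fused_moe_router_tensorcore(
    x=x,
    router_weight=router_weight,
    topk=2,
    moe_softcapping=30.0,
    BLOCK_SIZE_M=32,
    BLOCK_SIZE_N=16,
    BLOCK_SIZE_K=256,
    correction_bias=None,
)
assert topk_weights.shape == (1024, 2) and topk_ids.shape == (1024, 2)
```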
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
     enable_deterministic_inference: bool = False,
 ):
     assert not renormalize
     assert (
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
     )
     bs, hidden_dim = hidden_states.shape
     num_experts = gating_output.shape[0]
     BLOCK_SIZE_M = 32
-    BLOCK_SIZE_N = 16
-    BLOCK_SIZE_K = 256
+    BLOCK_SIZE_N = max(num_experts, 16)
+    BLOCK_SIZE_K = (
+        256 if num_experts < 256 else 64
+    )  # if experts are large, need to use smaller k block or shared memory OOM

     if (
-        bs >= 512
+        (bs >= 512 or num_experts > 8)
         and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
         # we keep using single kernel to avoid non-deterministic behavior
         and not enable_deterministic_inference
     ):
-        return fused_moe_router_large_bs_impl(
+        # if large batch size or large expert, use kernel that uses tensorcore in matmul
+        return fused_moe_router_tensorcore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
@@ -341,9 +376,11 @@ def fused_moe_router_shim(
             BLOCK_SIZE_M=BLOCK_SIZE_M,
             BLOCK_SIZE_N=BLOCK_SIZE_N,
             BLOCK_SIZE_K=BLOCK_SIZE_K,
+            correction_bias=correction_bias,
         )
     else:
-        return fused_moe_router_impl(
+        # if smaller, use kernel that does not use tensorcore in matmul
+        return fused_moe_router_cudacore(
             x=hidden_states,
             router_weight=gating_output,
             topk=topk,
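
The dispatch rule these two hunks complete, restated as a plain predicate (a sketch mirroring the diff's condition; the helper name is ours):

```python
def use_tensorcore_router(bs, topk, num_experts, hidden_dim,
                          block_n, block_k, deterministic):
    return (
        (bs >= 512 or num_experts > 8)  # enough rows/experts to feed tensor cores
        and topk <= 2                   # tensorcore kernel only supports topk <= 2
        and num_experts <= block_n      # all experts fit in one N tile
        and hidden_dim % block_k == 0   # K loop requires exact tiling
        and not deterministic           # keep the single-kernel path for determinism
    )
```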
@@ -380,11 +417,10 @@ class FusedMoeRouter:
             renormalize=False,
         )

-    def forward_vllm(
+    def forward_torch(
         self,
         x: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # g, _ = self.router_linear.forward(x)
         g = x.float() @ self.router_linear.weight.T.float()

         g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
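
The renamed `forward_torch` keeps an eager-mode reference path: a float32 matmul against the router weight followed by the same tanh softcapping. A quick illustrative check that this eager form matches the `(e^{2y} - 1) / (e^{2y} + 1)` formulation used in the Triton kernels:

```python
import torch

g = torch.randn(4, 8)
cap = 30.0
eager = torch.tanh(g / cap) * cap              # forward_torch's form
exped = torch.exp(2 * g / cap)
kernel_form = (exped - 1) / (exped + 1) * cap  # the kernels' form
assert torch.allclose(eager, kernel_form, atol=1e-6)
```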