Add some fused elementwise kernels for grok-1 (#4398)
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
342
python/sglang/srt/layers/moe/router.py
Normal file
@@ -0,0 +1,342 @@
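# Fused Triton router kernels for Grok-1 style MoE: softcapped router logits
# followed by top-k expert selection, producing (topk_weights, topk_ids).
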
from typing import Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.moe.topk import fused_topk


@triton.jit
def fused_moe_router_kernel(
    input_ptr,  # input (bs, hidden_dim)
    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
    topk_weights_ptr,  # output (bs, topk)
    topk_ids_ptr,  # output (bs, topk)
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    moe_softcapping: tl.constexpr,
    moe_renormalize: tl.constexpr,  # not supported
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim

    # moe_router_weight is k major
    expert_offsets = tl.arange(0, num_experts)[:, None]
    router_mask = mask[None, :]
    w_router = tl.load(
        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
        mask=router_mask,
        other=0.0,
    )

    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)

    # todo: tl.dot?
    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

    # logit softcap
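    # tanh(z) = (exp(2z) - 1) / (exp(2z) + 1), so the lines below compute
    # moe_softcapping * tanh(logits / moe_softcapping) without a tanh intrinsic.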
    logits_scaled = logits / moe_softcapping
    exped = tl.exp(2 * logits_scaled)
    top = exped - 1
    bottom = exped + 1
    logits_softcapped = top / bottom * moe_softcapping

    # topk
    # assert 1 <= topk <= num_experts
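
    # The stored top-k weights are softmax probabilities over the softcapped
    # logits; the running max (top1_v) is subtracted before exponentiating
    # for numerical stability.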

    # 5.38 us

    top1 = tl.argmax(logits_softcapped, axis=0)
    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us

    top1_v = tl.max(logits_softcapped, axis=0)
    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)

    tl.store(
        topk_weights_ptr + pid * topk + 0,
        invsumexp,
    )  # 5.73 us

    if topk >= 2:
        top2 = tl.argmax(
            tl.where(
                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
            ),
            axis=0,
        )
        tl.store(topk_ids_ptr + pid * topk + 1, top2)
        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
        tl.store(
            topk_weights_ptr + pid * topk + 1,
            tl.exp(top2_v - top1_v) * invsumexp,
        )  # 5.95 us

    # probably slow
    if topk > 2:
        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
        topk_mask = tl.where(
            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
        )
        topk_mask = tl.where(
            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
        )
        for i in range(2, topk):
            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
            topk_mask = tl.where(
                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
            )
            tl.store(topk_ids_ptr + pid * topk + i, topi)
            topi_v = tl.sum(
                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
            )
            tl.store(
                topk_weights_ptr + pid * topk + i,
                tl.exp(topi_v - top1_v) * invsumexp,
            )
    # assert not moe_renormalize, "moe weight renormalization not implemented"


def fused_moe_router_impl(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    topk: int,
    moe_softcapping: float,
):
    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
    bs, hidden_dim = x.shape
    num_experts = router_weight.shape[0]

    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
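
    # One Triton program handles one token (grid size == bs); BLOCK_SIZE is the
    # next power of two >= hidden_dim, so each program loads the full hidden
    # dimension in a single block.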
    grid = lambda meta: (bs,)
    config = {
        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
        ),
    }

    fused_moe_router_kernel[grid](
        x,
        router_weight,
        topk_weights,
        topk_ids,
        num_experts=num_experts,
        topk=topk,
        moe_softcapping=moe_softcapping,
        moe_renormalize=False,
        hidden_dim=hidden_dim,
        **config,
    )

    return topk_weights, topk_ids


@triton.jit
def fused_moe_router_large_bs_kernel(
    a_ptr,  # input (bs, hidden_dim)
    b_ptr,  # input (num_experts, hidden_dim)
    topk_weights_ptr,  # output (bs, topk)
    topk_ids_ptr,  # output (bs, topk)
    bs,
    num_experts: tl.constexpr,
    topk: tl.constexpr,  # only support topk == 1
    moe_softcapping: tl.constexpr,
    moe_renormalize: tl.constexpr,  # not supported
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bn: tl.constexpr,
):
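    # Each program computes a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    # a @ b.T with tl.dot, i.e. the router logits for BLOCK_SIZE_M tokens
    # at once; only top-1 routing is supported in this kernel.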

    # 1. get block id
    pid = tl.program_id(axis=0)

    # 2. create pointers for the first block of A and B
    # 2.1. setup a_ptrs with offsets in m and k
    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
    bs_mask = offs_m < bs
    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)

    # 2.2. setup b_ptrs with offsets in k and n.
    # Note: b matrix is k-major.
    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
    expert_mask = offs_n < num_experts
    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)

    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
    # 3.1. iterate in K dimension
    # 3.2. transpose tile B
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
        a = tl.load(
            a_ptrs,
            mask=bs_mask,
            other=0.0,
        ).to(tl.float32)
        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
        acc += tl.dot(a, b)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K

    # 4. logit softcap
    logits_scaled = acc / moe_softcapping
    exped = tl.exp(2 * logits_scaled)
    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

    # 5. top1
    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
    top1_v = tl.max(
        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
    )
    invsumexp = 1.0 / tl.sum(
        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
    )

    # 6. store to output
    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    topk_mask = offs_topk < bs
    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
    tl.store(
        topk_weights_ptr + offs_topk,
        invsumexp,
        mask=topk_mask,
    )


def fused_moe_router_large_bs_impl(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    topk: int,
    moe_softcapping: float,
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
):
    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
    bs, hidden_dim = x.shape
    num_experts = router_weight.shape[0]

    assert num_experts <= BLOCK_SIZE_N
    assert hidden_dim % BLOCK_SIZE_K == 0
    assert topk == 1

    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)

    fused_moe_router_large_bs_kernel[grid](
        a_ptr=x,
        b_ptr=router_weight,
        topk_weights_ptr=topk_weights,
        topk_ids_ptr=topk_ids,
        bs=bs,
        num_experts=num_experts,
        topk=topk,
        moe_softcapping=moe_softcapping,
        moe_renormalize=False,
        K=hidden_dim,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        stride_am=hidden_dim,
        stride_bn=hidden_dim,
    )

    return topk_weights, topk_ids


def fused_moe_router_shim(
    moe_softcapping,
    hidden_states,
    gating_output,
    topk,
    renormalize,
):
    assert not renormalize
    assert (
        len(hidden_states.shape) == 2
        and hidden_states.shape[1] == gating_output.shape[1]
    )
    bs, hidden_dim = hidden_states.shape
    num_experts = gating_output.shape[0]
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16
    BLOCK_SIZE_K = 256
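    # Dispatch: use the tiled GEMM kernel for large batches with top-1 routing,
    # when all experts fit in one BLOCK_SIZE_N tile and hidden_dim is a
    # multiple of BLOCK_SIZE_K; otherwise fall back to the per-token kernel.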
    if (
        bs >= 512
        and topk == 1
        and num_experts <= BLOCK_SIZE_N
        and hidden_dim % BLOCK_SIZE_K == 0
    ):
        return fused_moe_router_large_bs_impl(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
            moe_softcapping=moe_softcapping,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
        )
    else:
        return fused_moe_router_impl(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
            moe_softcapping=moe_softcapping,
        )


class FusedMoeRouter:
    def __init__(self, router_linear, topk, moe_softcapping) -> None:
        self.router_linear = router_linear
        self.topk = topk
        self.moe_softcapping = moe_softcapping

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if x.is_cuda:
            return self.forward_cuda(x, residual)
        else:
            return self.forward_vllm(x, residual)

    def forward_cuda(
        self, x: torch.Tensor, autotune=False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return fused_moe_router_shim(
            moe_softcapping=self.moe_softcapping,
            hidden_states=x,
            gating_output=self.router_linear.weight,
            topk=self.topk,
            renormalize=False,
        )

    def forward_vllm(
        self,
        x: torch.Tensor,
        residual: torch.Tensor = None,  # unused; accepted because forward() passes it
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # g, _ = self.router_linear.forward(x)
        g = x.float() @ self.router_linear.weight.T.float()

        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping

        return fused_topk(x, g, self.topk, False)
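
For orientation, a minimal usage sketch (not part of the commit) of how FusedMoeRouter might be wired up. The hidden size, expert count, topk, softcap value, and the nn.Linear router below are made-up placeholders, not Grok-1's real configuration:

    import torch
    from torch import nn

    from sglang.srt.layers.moe.router import FusedMoeRouter

    hidden_dim, num_experts, topk, softcap = 1024, 8, 2, 30.0
    router_linear = nn.Linear(hidden_dim, num_experts, bias=False).cuda().half()
    router = FusedMoeRouter(router_linear, topk=topk, moe_softcapping=softcap)

    x = torch.randn(4, hidden_dim, device="cuda", dtype=torch.float16)
    # forward() takes a residual argument but does not use it for routing.
    topk_weights, topk_ids = router(x, x)
    print(topk_weights.shape, topk_ids.shape)  # (4, 2) and (4, 2)

The returned weights are softmax probabilities over all experts of the softcapped logits; renormalization over the selected top-k is not supported by these kernels.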