Add some fused elementwise kernels for grok-1 (#4398)
Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
This commit is contained in:
411
python/sglang/srt/layers/elementwise.py
Normal file
411
python/sglang/srt/layers/elementwise.py
Normal file
@@ -0,0 +1,411 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
fused_softcap_autotune = triton.autotune(
|
||||
configs=[
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
|
||||
],
|
||||
key=["n_ele"],
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_softcap_kernel(
|
||||
output_ptr,
|
||||
input_ptr,
|
||||
n_ele,
|
||||
softcap_const: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_ele
|
||||
x = tl.load(input_ptr + offsets, mask=mask)
|
||||
fx = x.to(tl.float32)
|
||||
fxs = fx / softcap_const
|
||||
exped = tl.exp(2 * fxs)
|
||||
top = exped - 1
|
||||
bottom = exped + 1
|
||||
output = top / bottom * softcap_const
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
|
||||
fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
|
||||
|
||||
|
||||
def fused_softcap(x, softcap_const, autotune=False):
|
||||
output = torch.empty_like(x, dtype=torch.float32)
|
||||
n_elements = output.numel()
|
||||
if autotune:
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
|
||||
else:
|
||||
fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
|
||||
output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
# cast to float + softcap
|
||||
class Softcap:
|
||||
def __init__(self, softcap_const: float):
|
||||
self.softcap_const = softcap_const
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if x.is_cuda:
|
||||
return self.forward_cuda(x)
|
||||
else:
|
||||
return self.forward_native(x)
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
|
||||
return fused_softcap(x, self.softcap_const, autotune=autotune)
|
||||
|
||||
|
||||
rmsnorm_autotune = triton.autotune(
|
||||
configs=[
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
|
||||
triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
|
||||
],
|
||||
key=["hidden_dim"],
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_dual_residual_rmsnorm_kernel(
|
||||
output_ptr,
|
||||
mid_ptr,
|
||||
activ_ptr,
|
||||
residual_ptr,
|
||||
weight1_ptr,
|
||||
weight2_ptr,
|
||||
eps: tl.constexpr,
|
||||
hidden_dim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
input_start = pid * hidden_dim
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < hidden_dim
|
||||
|
||||
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
|
||||
a = a_.to(tl.float32)
|
||||
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
|
||||
|
||||
r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
|
||||
w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
|
||||
w1 = w1_.to(tl.float32)
|
||||
|
||||
a2r = r + (a / rms * w1).to(r.dtype)
|
||||
tl.store(
|
||||
mid_ptr + input_start + offsets,
|
||||
a2r,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
a2r = a2r.to(tl.float32)
|
||||
rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
|
||||
|
||||
w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
|
||||
w2 = w2_.to(tl.float32)
|
||||
|
||||
tl.store(
|
||||
output_ptr + input_start + offsets,
|
||||
a2r / rms2 * w2, # implicitly casts to output dtype here
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
|
||||
fused_dual_residual_rmsnorm_kernel
|
||||
)
|
||||
|
||||
|
||||
def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
|
||||
assert len(x.shape) == 2
|
||||
assert x.shape == residual.shape and x.dtype == residual.dtype
|
||||
output, mid = torch.empty_like(x), torch.empty_like(x)
|
||||
bs, hidden_dim = x.shape
|
||||
if autotune:
|
||||
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
|
||||
output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
|
||||
)
|
||||
else:
|
||||
config = {
|
||||
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
|
||||
),
|
||||
}
|
||||
|
||||
fused_dual_residual_rmsnorm_kernel[(bs,)](
|
||||
output,
|
||||
mid,
|
||||
x,
|
||||
residual,
|
||||
weight1,
|
||||
weight2,
|
||||
eps=eps,
|
||||
hidden_dim=hidden_dim,
|
||||
**config,
|
||||
)
|
||||
|
||||
return output, mid
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_rmsnorm_kernel(
|
||||
output_ptr,
|
||||
activ_ptr,
|
||||
weight_ptr,
|
||||
eps: tl.constexpr,
|
||||
hidden_dim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
input_start = pid * hidden_dim
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < hidden_dim
|
||||
|
||||
a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
|
||||
a = a_.to(tl.float32)
|
||||
rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
|
||||
|
||||
w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
|
||||
w1 = w1_.to(tl.float32)
|
||||
|
||||
a_rms = a / rms * w1
|
||||
|
||||
tl.store(
|
||||
output_ptr + input_start + offsets,
|
||||
a_rms, # implicitly casts to output dtype here
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
|
||||
assert len(x.shape) == 2
|
||||
if inplace:
|
||||
output = x
|
||||
else:
|
||||
output = torch.empty_like(x)
|
||||
bs, hidden_dim = x.shape
|
||||
config = {
|
||||
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
|
||||
),
|
||||
}
|
||||
|
||||
fused_rmsnorm_kernel[(bs,)](
|
||||
output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class FusedDualResidualRMSNorm:
|
||||
"""
|
||||
Fused implementation of
|
||||
y = RMSNorm2(RMSNorm1(x) + residual))
|
||||
"""
|
||||
|
||||
def __init__(self, rmsnorm1, rmsnorm2) -> None: # the one after rmsnorm1
|
||||
self.rmsnorm1 = rmsnorm1
|
||||
self.rmsnorm2 = rmsnorm2
|
||||
self.variance_epsilon = self.rmsnorm1.variance_epsilon
|
||||
assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
|
||||
assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if x.is_cuda:
|
||||
return self.forward_cuda(x, residual)
|
||||
else:
|
||||
return self.forward_flashinfer(x, residual)
|
||||
|
||||
def forward_cuda(
|
||||
self, x: torch.Tensor, residual: torch.Tensor, autotune=False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_dual_residual_rmsnorm(
|
||||
x,
|
||||
residual,
|
||||
self.rmsnorm1.weight,
|
||||
self.rmsnorm2.weight,
|
||||
self.variance_epsilon,
|
||||
autotune=autotune,
|
||||
)
|
||||
|
||||
def forward_flashinfer(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
normed1 = self.rmsnorm1(x)
|
||||
residual = normed1 + residual
|
||||
return self.rmsnorm2(residual), residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
normed1 = self.rmsnorm1.forward_native(x)
|
||||
residual = normed1 + residual
|
||||
return self.rmsnorm2.forward_native(residual), residual
|
||||
|
||||
|
||||
# gelu on first half of vector
|
||||
@triton.jit
|
||||
def gelu_and_mul_kernel(
|
||||
out_hidden_states_ptr, # (bs, hidden_dim)
|
||||
out_scales_ptr, # (bs,)
|
||||
hidden_states_ptr, # (bs, hidden_dim * 2)
|
||||
quant_max: tl.constexpr,
|
||||
static_scale: tl.constexpr,
|
||||
hidden_dim: tl.constexpr, # the output hidden_dim
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
input_start = pid * hidden_dim * 2
|
||||
output_start = pid * hidden_dim
|
||||
|
||||
input1_offs = tl.arange(0, BLOCK_SIZE)
|
||||
mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
|
||||
input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
|
||||
output_offs = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
x1 = tl.load(
|
||||
hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
|
||||
).to(tl.float32)
|
||||
x3 = tl.load(
|
||||
hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
# gelu
|
||||
# cast down before mul to better match training?
|
||||
gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
|
||||
out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
|
||||
|
||||
if quant_max is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
|
||||
|
||||
|
||||
def gelu_and_mul_triton(
|
||||
hidden_states,
|
||||
scales=None,
|
||||
quantize=None, # dtype to quantize to
|
||||
out=None,
|
||||
):
|
||||
bs, in_hidden_dim = hidden_states.shape
|
||||
hidden_dim = in_hidden_dim // 2
|
||||
|
||||
if out is None:
|
||||
out_hidden_states = torch.empty(
|
||||
(bs, hidden_dim),
|
||||
dtype=quantize or hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
else:
|
||||
assert out.shape == (bs, hidden_dim)
|
||||
assert out.dtype == (quantize or hidden_states.dtype)
|
||||
out_hidden_states = out
|
||||
out_scales = None
|
||||
static_scale = False
|
||||
if quantize is not None:
|
||||
if scales is None:
|
||||
out_scales = torch.empty(
|
||||
(bs,), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
else:
|
||||
out_scales = scales
|
||||
static_scale = True
|
||||
|
||||
config = {
|
||||
# 8 ele per thread (not tuned)
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
|
||||
),
|
||||
}
|
||||
|
||||
gelu_and_mul_kernel[(bs,)](
|
||||
out_hidden_states,
|
||||
out_scales,
|
||||
hidden_states,
|
||||
quant_max=torch.finfo(quantize).max if quantize is not None else None,
|
||||
static_scale=static_scale,
|
||||
hidden_dim=hidden_dim,
|
||||
BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
|
||||
**config,
|
||||
)
|
||||
|
||||
if quantize is not None:
|
||||
return out_hidden_states, out_scales
|
||||
else:
|
||||
return out_hidden_states, None
|
||||
342
python/sglang/srt/layers/moe/router.py
Normal file
342
python/sglang/srt/layers/moe/router.py
Normal file
@@ -0,0 +1,342 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.topk import fused_topk
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_router_kernel(
|
||||
input_ptr, # input (bs, hidden_dim)
|
||||
moe_router_weight_ptr, # input (num_experts, hidden_dim)
|
||||
topk_weights_ptr, # output (bs, topk)
|
||||
topk_ids_ptr, # output (bs, topk)
|
||||
num_experts: tl.constexpr,
|
||||
topk: tl.constexpr,
|
||||
moe_softcapping: tl.constexpr,
|
||||
moe_renormalize: tl.constexpr, # not supported
|
||||
hidden_dim: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < hidden_dim
|
||||
|
||||
# moe_router_weight is k major
|
||||
expert_offsets = tl.arange(0, num_experts)[:, None]
|
||||
router_mask = mask[None, :]
|
||||
w_router = tl.load(
|
||||
moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
|
||||
mask=router_mask,
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
|
||||
|
||||
# todo: tl.dot?
|
||||
logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
|
||||
|
||||
# logit softcap
|
||||
logits_scaled = logits / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
top = exped - 1
|
||||
bottom = exped + 1
|
||||
logits_softcapped = top / bottom * moe_softcapping
|
||||
|
||||
# topk
|
||||
# assert 1 <= topk <= num_experts
|
||||
|
||||
# 5.38 us
|
||||
|
||||
top1 = tl.argmax(logits_softcapped, axis=0)
|
||||
tl.store(topk_ids_ptr + pid * topk + 0, top1) # 5.63 us
|
||||
|
||||
top1_v = tl.max(logits_softcapped, axis=0)
|
||||
invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
|
||||
|
||||
tl.store(
|
||||
topk_weights_ptr + pid * topk + 0,
|
||||
invsumexp,
|
||||
) # 5.73 us
|
||||
|
||||
if topk >= 2:
|
||||
top2 = tl.argmax(
|
||||
tl.where(
|
||||
tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
|
||||
),
|
||||
axis=0,
|
||||
)
|
||||
tl.store(topk_ids_ptr + pid * topk + 1, top2)
|
||||
top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
|
||||
tl.store(
|
||||
topk_weights_ptr + pid * topk + 1,
|
||||
tl.exp(top2_v - top1_v) * invsumexp,
|
||||
) # 5.95us
|
||||
|
||||
# probably slow
|
||||
if topk > 2:
|
||||
topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
|
||||
topk_mask = tl.where(
|
||||
tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
|
||||
)
|
||||
topk_mask = tl.where(
|
||||
tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
|
||||
)
|
||||
for i in range(2, topk):
|
||||
topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
|
||||
topk_mask = tl.where(
|
||||
tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
|
||||
)
|
||||
tl.store(topk_ids_ptr + pid * topk + i, topi)
|
||||
topi_v = tl.sum(
|
||||
logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
|
||||
)
|
||||
tl.store(
|
||||
topk_weights_ptr + pid * topk + i,
|
||||
tl.exp(topi_v - top1_v) * invsumexp,
|
||||
)
|
||||
# assert not moe_renormalize, "moe weight renormalization not implemented"
|
||||
|
||||
|
||||
def fused_moe_router_impl(
|
||||
x: torch.Tensor,
|
||||
router_weight: torch.Tensor,
|
||||
topk: int,
|
||||
moe_softcapping: float,
|
||||
):
|
||||
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
|
||||
bs, hidden_dim = x.shape
|
||||
num_experts = router_weight.shape[0]
|
||||
|
||||
# router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
|
||||
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
|
||||
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
|
||||
|
||||
grid = lambda meta: (bs,)
|
||||
config = {
|
||||
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
|
||||
),
|
||||
}
|
||||
|
||||
fused_moe_router_kernel[grid](
|
||||
x,
|
||||
router_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts=num_experts,
|
||||
topk=topk,
|
||||
moe_softcapping=moe_softcapping,
|
||||
moe_renormalize=False,
|
||||
hidden_dim=hidden_dim,
|
||||
**config,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_router_large_bs_kernel(
|
||||
a_ptr, # input (bs, hidden_dim)
|
||||
b_ptr, # input (num_experts, hidden_dim)
|
||||
topk_weights_ptr, # output (bs, topk)
|
||||
topk_ids_ptr, # output (bs, topk)
|
||||
bs,
|
||||
num_experts: tl.constexpr,
|
||||
topk: tl.constexpr, # only support topk == 1
|
||||
moe_softcapping: tl.constexpr,
|
||||
moe_renormalize: tl.constexpr, # not supported
|
||||
K: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
stride_am: tl.constexpr,
|
||||
stride_bn: tl.constexpr,
|
||||
):
|
||||
|
||||
# 1. get block id
|
||||
pid = tl.program_id(axis=0)
|
||||
|
||||
# 2. create pointers for the first block of A and B
|
||||
# 2.1. setup a_ptrs with offsets in m and k
|
||||
offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
|
||||
bs_mask = offs_m < bs
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
|
||||
a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
|
||||
|
||||
# 2.2. setup b_ptrs with offsets in k and n.
|
||||
# Note: b matrix is k-major.
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
|
||||
expert_mask = offs_n < num_experts
|
||||
b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
|
||||
|
||||
# 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
|
||||
# 3.1. iterate in K dimension
|
||||
# 3.2. transpose tile B
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, K // BLOCK_SIZE_K): # hidden_dim % BLOCK_SIZE_K == 0
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=bs_mask,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
|
||||
acc += tl.dot(a, b)
|
||||
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += BLOCK_SIZE_K
|
||||
|
||||
# 4. logit softcap
|
||||
logits_scaled = acc / moe_softcapping
|
||||
exped = tl.exp(2 * logits_scaled)
|
||||
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
|
||||
|
||||
# 5. top1
|
||||
cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
|
||||
top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
|
||||
top1_v = tl.max(
|
||||
tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
|
||||
)
|
||||
invsumexp = 1.0 / tl.sum(
|
||||
tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
|
||||
)
|
||||
|
||||
# 6. store to output
|
||||
offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
topk_mask = offs_topk < bs
|
||||
tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
|
||||
tl.store(
|
||||
topk_weights_ptr + offs_topk,
|
||||
invsumexp,
|
||||
mask=topk_mask,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe_router_large_bs_impl(
|
||||
x: torch.Tensor,
|
||||
router_weight: torch.Tensor,
|
||||
topk: int,
|
||||
moe_softcapping: float,
|
||||
BLOCK_SIZE_M: int,
|
||||
BLOCK_SIZE_N: int,
|
||||
BLOCK_SIZE_K: int,
|
||||
):
|
||||
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
|
||||
bs, hidden_dim = x.shape
|
||||
num_experts = router_weight.shape[0]
|
||||
|
||||
assert num_experts <= BLOCK_SIZE_N
|
||||
assert hidden_dim % BLOCK_SIZE_K == 0
|
||||
assert topk == 1
|
||||
|
||||
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
|
||||
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
|
||||
|
||||
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
|
||||
|
||||
fused_moe_router_large_bs_kernel[grid](
|
||||
a_ptr=x,
|
||||
b_ptr=router_weight,
|
||||
topk_weights_ptr=topk_weights,
|
||||
topk_ids_ptr=topk_ids,
|
||||
bs=bs,
|
||||
num_experts=num_experts,
|
||||
topk=topk,
|
||||
moe_softcapping=moe_softcapping,
|
||||
moe_renormalize=False,
|
||||
K=hidden_dim,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
stride_am=hidden_dim,
|
||||
stride_bn=hidden_dim,
|
||||
)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def fused_moe_router_shim(
|
||||
moe_softcapping,
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
):
|
||||
assert not renormalize
|
||||
assert (
|
||||
len(hidden_states.shape) == 2
|
||||
and hidden_states.shape[1] == gating_output.shape[1]
|
||||
)
|
||||
bs, hidden_dim = hidden_states.shape
|
||||
num_experts = gating_output.shape[0]
|
||||
BLOCK_SIZE_M = 32
|
||||
BLOCK_SIZE_N = 16
|
||||
BLOCK_SIZE_K = 256
|
||||
if (
|
||||
bs >= 512
|
||||
and topk == 1
|
||||
and num_experts <= BLOCK_SIZE_N
|
||||
and hidden_dim % BLOCK_SIZE_K == 0
|
||||
):
|
||||
return fused_moe_router_large_bs_impl(
|
||||
x=hidden_states,
|
||||
router_weight=gating_output,
|
||||
topk=topk,
|
||||
moe_softcapping=moe_softcapping,
|
||||
BLOCK_SIZE_M=BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N=BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
)
|
||||
else:
|
||||
return fused_moe_router_impl(
|
||||
x=hidden_states,
|
||||
router_weight=gating_output,
|
||||
topk=topk,
|
||||
moe_softcapping=moe_softcapping,
|
||||
)
|
||||
|
||||
|
||||
class FusedMoeRouter:
|
||||
def __init__(self, router_linear, topk, moe_softcapping) -> None:
|
||||
self.router_linear = router_linear
|
||||
self.topk = topk
|
||||
self.moe_softcapping = moe_softcapping
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, residual: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if x.is_cuda:
|
||||
return self.forward_cuda(x, residual)
|
||||
else:
|
||||
return self.forward_vllm(x, residual)
|
||||
|
||||
def forward_cuda(
|
||||
self, x: torch.Tensor, autotune=False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return fused_moe_router_shim(
|
||||
moe_softcapping=self.moe_softcapping,
|
||||
hidden_states=x,
|
||||
gating_output=self.router_linear.weight,
|
||||
topk=self.topk,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
def forward_vllm(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# g, _ = self.router_linear.forward(x)
|
||||
g = x.float() @ self.router_linear.weight.T.float()
|
||||
|
||||
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
|
||||
|
||||
return fused_topk(x, g, self.topk, False)
|
||||
@@ -15,28 +15,36 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
|
||||
"""Inference-only Grok1 model."""
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.moe.router import fused_moe_router_shim
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
@@ -44,47 +52,17 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.loader import DefaultModelLoader
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.utils import dump_to_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Grok1MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
reduce_results=reduce_results,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
)
|
||||
self.act_fn = GeluAndMul(approximate="tanh")
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
debug_tensor_dump_output_folder = None
|
||||
debug_tensor_dump_inject = False
|
||||
|
||||
|
||||
class Grok1MoE(nn.Module):
|
||||
@@ -108,51 +86,55 @@ class Grok1MoE(nn.Module):
|
||||
tp_size: Optional[int] = None,
|
||||
reduce_results=True,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
# Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
|
||||
self.gate = ReplicatedLinear(
|
||||
hidden_size,
|
||||
num_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix=add_prefix("gate", prefix),
|
||||
)
|
||||
|
||||
self.router_logit_softcapping = getattr(
|
||||
config, "router_logit_softcapping", 30.0
|
||||
)
|
||||
self.experts = FusedMoE(
|
||||
custom_routing_function = functools.partial(
|
||||
fused_moe_router_shim, self.router_logit_softcapping
|
||||
)
|
||||
|
||||
kwargs = {}
|
||||
if global_server_args_dict["enable_ep_moe"]:
|
||||
MoEImpl = EPMoE
|
||||
else:
|
||||
MoEImpl = FusedMoE
|
||||
kwargs["reduce_results"] = reduce_results
|
||||
kwargs["use_presharded_weights"] = use_presharded_weights
|
||||
kwargs["inplace"] = inplace
|
||||
kwargs["no_combine"] = no_combine
|
||||
|
||||
self.experts = MoEImpl(
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
params_dtype=params_dtype,
|
||||
reduce_results=reduce_results,
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
custom_routing_function=custom_routing_function,
|
||||
activation="gelu",
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix("experts", prefix),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits = 30.0 * F.tanh(router_logits / 30.0)
|
||||
|
||||
# need to assert self.gate.quant_method is unquantized
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
return self.experts(hidden_states, self.gate.weight)
|
||||
|
||||
|
||||
class Grok1Attention(nn.Module):
|
||||
@@ -167,31 +149,33 @@ class Grok1Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
load_presharded_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_id = layer_id
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
attn_tp_rank = get_tensor_model_parallel_rank()
|
||||
attn_tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
assert self.total_num_heads % attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // attn_tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
if self.total_num_kv_heads >= attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||
self.head_dim = getattr(config, "head_dim", 128)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.load_presharded_attn = load_presharded_attn
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -200,7 +184,9 @@ class Grok1Attention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
load_presharded_attn=self.load_presharded_attn,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
@@ -208,7 +194,9 @@ class Grok1Attention(nn.Module):
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
use_presharded_weights=self.load_presharded_attn,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@@ -227,7 +215,6 @@ class Grok1Attention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
logit_cap=logit_cap,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -236,10 +223,73 @@ class Grok1Attention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
if hidden_states.shape[0] == 0:
|
||||
assert (
|
||||
not self.o_proj.reduce_results
|
||||
), "short-circuiting allreduce will lead to hangs"
|
||||
return hidden_states
|
||||
if debug_tensor_dump_output_folder:
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"attn_input_{self.layer_id}",
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
if debug_tensor_dump_inject:
|
||||
name = os.path.join(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"jax_dump_attn_input_{self.layer_id}.npy",
|
||||
)
|
||||
logger.info(f"Load {name} from jax.")
|
||||
x = np.load(name)
|
||||
hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
|
||||
hidden_states
|
||||
)
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
if debug_tensor_dump_output_folder:
|
||||
num_tokens = q.shape[0]
|
||||
num_heads_q = self.num_heads
|
||||
head_dim = self.head_dim
|
||||
num_heads_kv = k.numel() // (num_tokens * head_dim)
|
||||
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"q_{self.layer_id}",
|
||||
tensor_model_parallel_all_gather(
|
||||
q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
|
||||
).contiguous(),
|
||||
)
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"k_{self.layer_id}",
|
||||
tensor_model_parallel_all_gather(
|
||||
k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
|
||||
).contiguous(),
|
||||
)
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"v_{self.layer_id}",
|
||||
tensor_model_parallel_all_gather(
|
||||
v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
|
||||
).contiguous(),
|
||||
)
|
||||
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
|
||||
if debug_tensor_dump_output_folder:
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
f"attn_output_{self.layer_id}",
|
||||
tensor_model_parallel_all_gather(
|
||||
attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
|
||||
dim=1,
|
||||
).contiguous(),
|
||||
)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
@@ -250,8 +300,9 @@ class Grok1DecoderLayer(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
load_presharded_moe: bool = False,
|
||||
load_presharded_attn: bool = False,
|
||||
load_presharded_mlp: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
@@ -268,7 +319,8 @@ class Grok1DecoderLayer(nn.Module):
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
reduce_results=False,
|
||||
load_presharded_attn=load_presharded_attn,
|
||||
)
|
||||
self.block_sparse_moe = Grok1MoE(
|
||||
config=config,
|
||||
@@ -282,38 +334,68 @@ class Grok1DecoderLayer(nn.Module):
|
||||
),
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix("block_sparse_moe", prefix),
|
||||
use_presharded_weights=load_presharded_moe,
|
||||
inplace=True,
|
||||
no_combine=False, # just a suggestion to not combine topk
|
||||
)
|
||||
|
||||
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.ffn = self.block_sparse_moe
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
deferred_norm: Optional[RMSNorm] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
|
||||
# Self Attention
|
||||
hidden_states = (
|
||||
self.post_attn_norm(
|
||||
self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=self.pre_attn_norm(hidden_states),
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if deferred_norm is not None:
|
||||
assert residual is not None
|
||||
# here hidden_states is output of ffn, residual is residual from after previous attn layer
|
||||
hidden_states, residual = fused_dual_residual_rmsnorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
deferred_norm.weight,
|
||||
self.pre_attn_norm.weight,
|
||||
deferred_norm.variance_epsilon,
|
||||
)
|
||||
+ hidden_states
|
||||
else:
|
||||
# here hidden_states is the residual
|
||||
hidden_states, residual = (
|
||||
fused_rmsnorm(
|
||||
hidden_states,
|
||||
self.pre_attn_norm.weight,
|
||||
self.pre_attn_norm.variance_epsilon,
|
||||
),
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
hidden_states, residual = fused_dual_residual_rmsnorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.post_attn_norm.weight,
|
||||
self.pre_moe_norm.weight,
|
||||
self.post_attn_norm.variance_epsilon,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = (
|
||||
self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
|
||||
+ hidden_states
|
||||
)
|
||||
return hidden_states
|
||||
hidden_states = self.ffn(hidden_states)
|
||||
return hidden_states, residual, self.post_moe_norm # defer layernorm
|
||||
|
||||
|
||||
class Grok1Model(nn.Module):
|
||||
@@ -321,8 +403,10 @@ class Grok1Model(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_presharded_weights: bool = False,
|
||||
prefix: str = "",
|
||||
load_presharded_moe: bool = False,
|
||||
load_presharded_embedding: bool = False,
|
||||
load_presharded_attn: bool = False,
|
||||
load_presharded_mlp: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -332,7 +416,7 @@ class Grok1Model(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
use_presharded_weights=load_presharded_embedding,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
@@ -340,8 +424,9 @@ class Grok1Model(nn.Module):
|
||||
config,
|
||||
i,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=use_presharded_weights,
|
||||
prefix=add_prefix(f"layers.{i}", prefix),
|
||||
load_presharded_moe=load_presharded_moe,
|
||||
load_presharded_attn=load_presharded_attn,
|
||||
load_presharded_mlp=load_presharded_mlp,
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
@@ -361,10 +446,48 @@ class Grok1Model(nn.Module):
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
|
||||
residual, deferred_norm = None, None
|
||||
for i in range(len(self.layers)):
|
||||
hidden_states = self.layers[i](positions, hidden_states, forward_batch)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states.mul_(self.config.output_multiplier_scale)
|
||||
hidden_states, residual, deferred_norm = self.layers[i](
|
||||
positions, hidden_states, forward_batch, residual, deferred_norm
|
||||
)
|
||||
|
||||
if debug_tensor_dump_output_folder:
|
||||
hidden_states = (
|
||||
fused_rmsnorm(
|
||||
hidden_states,
|
||||
deferred_norm.weight,
|
||||
deferred_norm.variance_epsilon,
|
||||
)
|
||||
+ residual
|
||||
)
|
||||
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
"last_hidden_before_norm",
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
hidden_states = fused_rmsnorm(
|
||||
hidden_states,
|
||||
self.norm.weight,
|
||||
self.norm.variance_epsilon,
|
||||
)
|
||||
|
||||
dump_to_file(
|
||||
debug_tensor_dump_output_folder,
|
||||
"last_hidden_after_norm",
|
||||
hidden_states,
|
||||
)
|
||||
else:
|
||||
hidden_states, _ = fused_dual_residual_rmsnorm(
|
||||
hidden_states,
|
||||
residual,
|
||||
deferred_norm.weight,
|
||||
self.norm.weight,
|
||||
deferred_norm.variance_epsilon,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -373,31 +496,77 @@ class Grok1ForCausalLM(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
if (
|
||||
# Get presharded weights.
|
||||
self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
|
||||
self.load_presharded_moe = (
|
||||
self.config.num_local_experts > 0
|
||||
and get_tensor_model_parallel_world_size() > 1
|
||||
):
|
||||
self.use_presharded_weights = True
|
||||
)
|
||||
self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
|
||||
self.load_presharded_embedding = getattr(
|
||||
config, "load_presharded_embedding", False
|
||||
)
|
||||
|
||||
self.is_weights_presharded = (
|
||||
self.load_presharded_mlp
|
||||
or self.load_presharded_moe
|
||||
or self.load_presharded_attn
|
||||
or self.load_presharded_embedding
|
||||
)
|
||||
|
||||
if self.is_weights_presharded:
|
||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||
else:
|
||||
self.use_presharded_weights = False
|
||||
|
||||
default_replicate_lm_head = False
|
||||
self.replicate_lm_head = getattr(
|
||||
config, "replicate_lm_head", default_replicate_lm_head
|
||||
)
|
||||
|
||||
self.model = Grok1Model(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
use_presharded_weights=self.use_presharded_weights,
|
||||
prefix=add_prefix("model", prefix),
|
||||
load_presharded_moe=self.load_presharded_moe,
|
||||
load_presharded_embedding=self.load_presharded_embedding,
|
||||
load_presharded_attn=self.load_presharded_attn,
|
||||
load_presharded_mlp=self.load_presharded_mlp,
|
||||
)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
lm_head_params_dtype = None
|
||||
if self.replicate_lm_head:
|
||||
self.lm_head = ReplicatedLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
params_dtype=lm_head_params_dtype,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
use_presharded_weights=self.load_presharded_embedding,
|
||||
params_dtype=lm_head_params_dtype,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
# Dump tensors for debugging
|
||||
global debug_tensor_dump_output_folder, debug_tensor_dump_inject
|
||||
debug_tensor_dump_output_folder = global_server_args_dict[
|
||||
"debug_tensor_dump_output_folder"
|
||||
]
|
||||
debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
logger.info(
|
||||
f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
|
||||
f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -406,6 +575,9 @@ class Grok1ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if debug_tensor_dump_output_folder:
|
||||
dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
|
||||
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
@@ -414,21 +586,28 @@ class Grok1ForCausalLM(nn.Module):
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
):
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
stacked_params_mapping = [
|
||||
num_experts: Optional[int] = None,
|
||||
ignore_parent_name: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if num_experts is None:
|
||||
num_experts = self.config.num_local_experts
|
||||
stacked_params_mapping = []
|
||||
stacked_params_mapping += [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
stacked_params_mapping += [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="w1",
|
||||
ckpt_down_proj_name="w2",
|
||||
ckpt_up_proj_name="w3",
|
||||
@@ -439,14 +618,25 @@ class Grok1ForCausalLM(nn.Module):
|
||||
all_names = set(params_dict.keys())
|
||||
hit_names = set()
|
||||
|
||||
def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
|
||||
def load_weight_wrapper(
|
||||
name: str, loaded_weight: torch.Tensor, *args, **kwargs
|
||||
):
|
||||
if ignore_parent_name:
|
||||
name = name.split(".")[-1]
|
||||
|
||||
if name not in params_dict:
|
||||
return
|
||||
|
||||
# Fuse constant multipliers into the weights
|
||||
if "lm_head" in name:
|
||||
loaded_weight = (
|
||||
loaded_weight.to(torch.float32)
|
||||
* self.config.output_multiplier_scale
|
||||
)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight, *args, **kwargs)
|
||||
|
||||
hit_names.add(name)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
@@ -460,7 +650,6 @@ class Grok1ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
load_weight_wrapper(name, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
@@ -487,13 +676,79 @@ class Grok1ForCausalLM(nn.Module):
|
||||
|
||||
load_weight_wrapper(name=name, loaded_weight=loaded_weight)
|
||||
|
||||
if len(hit_names) > 5:
|
||||
missing = all_names - hit_names
|
||||
missing_exclude_scales = {x for x in missing if "scale" not in x}
|
||||
logger.info(
|
||||
f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
|
||||
)
|
||||
if len(missing_exclude_scales) > 0:
|
||||
raise ValueError(
|
||||
f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
|
||||
)
|
||||
|
||||
elif len(hit_names) == 0:
|
||||
raise ValueError("load_weights failed because it did not hit any names.")
|
||||
|
||||
return hit_names
|
||||
|
||||
def get_num_params_analytical(self):
|
||||
cfg = self.config
|
||||
moe_intermediate_size = getattr(
|
||||
cfg,
|
||||
"moe_intermediate_size",
|
||||
getattr(cfg, "intermediate_size", None),
|
||||
)
|
||||
num_experts = cfg.num_local_experts
|
||||
|
||||
wq = (
|
||||
cfg.num_hidden_layers
|
||||
* cfg.hidden_size
|
||||
* cfg.num_attention_heads
|
||||
* cfg.head_dim
|
||||
)
|
||||
wkv = (
|
||||
cfg.num_hidden_layers
|
||||
* cfg.hidden_size
|
||||
* cfg.num_key_value_heads
|
||||
* cfg.head_dim
|
||||
* 2
|
||||
)
|
||||
out = (
|
||||
cfg.num_hidden_layers
|
||||
* cfg.hidden_size
|
||||
* cfg.num_attention_heads
|
||||
* cfg.head_dim
|
||||
)
|
||||
ffn1 = (
|
||||
cfg.num_hidden_layers
|
||||
* num_experts
|
||||
* cfg.hidden_size
|
||||
* moe_intermediate_size
|
||||
* 2
|
||||
)
|
||||
ffn2 = (
|
||||
cfg.num_hidden_layers
|
||||
* num_experts
|
||||
* cfg.hidden_size
|
||||
* moe_intermediate_size
|
||||
)
|
||||
embed = cfg.hidden_size * cfg.vocab_size * 2
|
||||
return wq + wkv + out + ffn1 + ffn2 + embed
|
||||
|
||||
def get_num_params_torch(self):
|
||||
return (
|
||||
sum(p.numel() for p in self.parameters())
|
||||
* get_tensor_model_parallel_world_size()
|
||||
)
|
||||
|
||||
|
||||
old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
|
||||
|
||||
|
||||
def _prepare_presharded_weights(
|
||||
self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
) -> Tuple[str, list[str], bool]:
|
||||
import glob
|
||||
import os
|
||||
|
||||
@@ -522,7 +777,7 @@ def _prepare_presharded_weights(
|
||||
# The new format
|
||||
allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
hf_weights_files = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user