diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py new file mode 100644 index 000000000..931dd2b9f --- /dev/null +++ b/python/sglang/srt/layers/elementwise.py @@ -0,0 +1,411 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + +fused_softcap_autotune = triton.autotune( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4), + triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4), + triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4), + triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32), + ], + key=["n_ele"], +) + + +@triton.jit +def fused_softcap_kernel( + output_ptr, + input_ptr, + n_ele, + softcap_const: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_ele + x = tl.load(input_ptr + offsets, mask=mask) + fx = x.to(tl.float32) + fxs = fx / softcap_const + exped = tl.exp(2 * fxs) + top = exped - 1 + bottom = exped + 1 + output = top / bottom * softcap_const + tl.store(output_ptr + offsets, output, mask=mask) + + +fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel) + + +def fused_softcap(x, softcap_const, autotune=False): + output = torch.empty_like(x, dtype=torch.float32) + n_elements = output.numel() + if autotune: + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const) + else: + fused_softcap_kernel[(triton.cdiv(n_elements, 128),)]( + output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8 + ) + return output + + +# cast to float + softcap +class Softcap: + def __init__(self, softcap_const: float): + self.softcap_const = softcap_const + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.is_cuda: + return self.forward_cuda(x) + else: + return self.forward_native(x) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + return torch.tanh(x.float() / self.softcap_const) * self.softcap_const + + def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor: + return fused_softcap(x, self.softcap_const, autotune=autotune) + + +rmsnorm_autotune = triton.autotune( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16), + 
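# --- Illustrative sanity check, not part of the patch above ----------------
# fused_softcap_kernel uses the identity tanh(z) = (exp(2z) - 1) / (exp(2z) + 1),
# so its output should match Softcap.forward_native (tanh(x / c) * c). A minimal
# check, assuming a CUDA device and that the module imports as added here:
import torch

from sglang.srt.layers.elementwise import Softcap, fused_softcap

if torch.cuda.is_available():
    x = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
    cap = Softcap(softcap_const=30.0)
    ref = cap.forward_native(x)       # eager tanh formula, float32
    out = fused_softcap(x, 30.0)      # Triton kernel, float32
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)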
triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8), + triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8), + triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4), + triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4), + ], + key=["hidden_dim"], +) + + +@triton.jit +def fused_dual_residual_rmsnorm_kernel( + output_ptr, + mid_ptr, + activ_ptr, + residual_ptr, + weight1_ptr, + weight2_ptr, + eps: tl.constexpr, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + input_start = pid * hidden_dim + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0) + a = a_.to(tl.float32) + rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps) + + r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0) + w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0) + w1 = w1_.to(tl.float32) + + a2r = r + (a / rms * w1).to(r.dtype) + tl.store( + mid_ptr + input_start + offsets, + a2r, + mask=mask, + ) + + a2r = a2r.to(tl.float32) + rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps) + + w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0) + w2 = w2_.to(tl.float32) + + tl.store( + output_ptr + input_start + offsets, + a2r / rms2 * w2, # implicitly casts to output dtype here + mask=mask, + ) + + +fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune( + fused_dual_residual_rmsnorm_kernel +) + + +def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False): + assert len(x.shape) == 2 + assert x.shape == residual.shape and x.dtype == residual.dtype + output, mid = torch.empty_like(x), torch.empty_like(x) + bs, hidden_dim = x.shape + if autotune: + 
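# --- Illustrative eager reference, not part of the patch above -------------
# Per row, fused_dual_residual_rmsnorm_kernel computes
#   mid = residual + RMSNorm1(x)   and   output = RMSNorm2(mid)
# in one pass. An equivalent PyTorch reference (helper name is hypothetical),
# useful for unit-testing the fused path:
import torch

def dual_residual_rmsnorm_ref(x, residual, weight1, weight2, eps):
    xf = x.float()
    rms1 = torch.sqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    mid = residual + (xf / rms1 * weight1.float()).to(residual.dtype)
    midf = mid.float()
    rms2 = torch.sqrt(midf.pow(2).mean(dim=-1, keepdim=True) + eps)
    output = (midf / rms2 * weight2.float()).to(x.dtype)
    return output, mid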
fused_dual_residual_rmsnorm_kernel_autotune[(bs,)]( + output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim + ) + else: + config = { + "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4 + ), + } + + fused_dual_residual_rmsnorm_kernel[(bs,)]( + output, + mid, + x, + residual, + weight1, + weight2, + eps=eps, + hidden_dim=hidden_dim, + **config, + ) + + return output, mid + + +@triton.jit +def fused_rmsnorm_kernel( + output_ptr, + activ_ptr, + weight_ptr, + eps: tl.constexpr, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + input_start = pid * hidden_dim + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0) + a = a_.to(tl.float32) + rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps) + + w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0) + w1 = w1_.to(tl.float32) + + a_rms = a / rms * w1 + + tl.store( + output_ptr + input_start + offsets, + a_rms, # implicitly casts to output dtype here + mask=mask, + ) + + +def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False): + assert len(x.shape) == 2 + if inplace: + output = x + else: + output = torch.empty_like(x) + bs, hidden_dim = x.shape + config = { + "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4 + ), + } + + fused_rmsnorm_kernel[(bs,)]( + output, x, weight, eps=eps, hidden_dim=hidden_dim, **config + ) + return output + + +class FusedDualResidualRMSNorm: + """ + Fused implementation of + y = RMSNorm2(RMSNorm1(x) + residual)) + """ + + def __init__(self, rmsnorm1, rmsnorm2) -> None: # the one after rmsnorm1 + self.rmsnorm1 = rmsnorm1 + self.rmsnorm2 = rmsnorm2 + self.variance_epsilon = self.rmsnorm1.variance_epsilon + assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon + assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, x: torch.Tensor, residual: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if x.is_cuda: + return self.forward_cuda(x, residual) + else: + return self.forward_flashinfer(x, residual) + + def forward_cuda( + self, x: torch.Tensor, residual: torch.Tensor, autotune=False + ) -> Tuple[torch.Tensor, torch.Tensor]: + return fused_dual_residual_rmsnorm( + x, + residual, + self.rmsnorm1.weight, + self.rmsnorm2.weight, + self.variance_epsilon, + autotune=autotune, + ) + + def forward_flashinfer( + self, + x: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + normed1 = self.rmsnorm1(x) + residual = normed1 + residual + return self.rmsnorm2(residual), residual + + def forward_native( + self, + x: torch.Tensor, + residual: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + normed1 = self.rmsnorm1.forward_native(x) + residual = normed1 + residual + return self.rmsnorm2.forward_native(residual), residual + + +# gelu on first half of vector +@triton.jit +def gelu_and_mul_kernel( + out_hidden_states_ptr, # (bs, hidden_dim) + out_scales_ptr, # (bs,) + hidden_states_ptr, # (bs, hidden_dim * 2) + quant_max: tl.constexpr, + static_scale: tl.constexpr, + hidden_dim: tl.constexpr, # the output hidden_dim + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + input_start = pid * hidden_dim * 2 + output_start = pid * 
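# --- Illustrative usage, not part of the patch above -----------------------
# FusedDualResidualRMSNorm wraps two existing RMSNorm modules and, on CUDA,
# should agree with the eager composition RMSNorm2(RMSNorm1(x) + residual).
# Shapes, dtype, and tolerances below are hypothetical:
import torch

from sglang.srt.layers.elementwise import FusedDualResidualRMSNorm
from sglang.srt.layers.layernorm import RMSNorm

if torch.cuda.is_available():
    norm1 = RMSNorm(4096, eps=1e-5).to(device="cuda", dtype=torch.bfloat16)
    norm2 = RMSNorm(4096, eps=1e-5).to(device="cuda", dtype=torch.bfloat16)
    fused = FusedDualResidualRMSNorm(norm1, norm2)

    x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
    residual = torch.randn_like(x)
    out, new_residual = fused(x, residual)                    # Triton path
    ref_out, ref_residual = fused.forward_native(x, residual)
    torch.testing.assert_close(out, ref_out, rtol=2e-2, atol=2e-2)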
hidden_dim + + input1_offs = tl.arange(0, BLOCK_SIZE) + mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output + input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE) + output_offs = tl.arange(0, BLOCK_SIZE) + + x1 = tl.load( + hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0 + ).to(tl.float32) + x3 = tl.load( + hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0 + ).to(tl.float32) + + # gelu + # cast down before mul to better match training? + gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1 + out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty) + + if quant_max is not None: + raise NotImplementedError() + + tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask) + + +def gelu_and_mul_triton( + hidden_states, + scales=None, + quantize=None, # dtype to quantize to + out=None, +): + bs, in_hidden_dim = hidden_states.shape + hidden_dim = in_hidden_dim // 2 + + if out is None: + out_hidden_states = torch.empty( + (bs, hidden_dim), + dtype=quantize or hidden_states.dtype, + device=hidden_states.device, + ) + else: + assert out.shape == (bs, hidden_dim) + assert out.dtype == (quantize or hidden_states.dtype) + out_hidden_states = out + out_scales = None + static_scale = False + if quantize is not None: + if scales is None: + out_scales = torch.empty( + (bs,), dtype=torch.float32, device=hidden_states.device + ) + else: + out_scales = scales + static_scale = True + + config = { + # 8 ele per thread (not tuned) + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4 + ), + } + + gelu_and_mul_kernel[(bs,)]( + out_hidden_states, + out_scales, + hidden_states, + quant_max=torch.finfo(quantize).max if quantize is not None else None, + static_scale=static_scale, + hidden_dim=hidden_dim, + BLOCK_SIZE=triton.next_power_of_2(hidden_dim), + **config, + ) + + if quantize is not None: + return out_hidden_states, out_scales + else: + return out_hidden_states, None diff --git a/python/sglang/srt/layers/moe/router.py b/python/sglang/srt/layers/moe/router.py new file mode 100644 index 000000000..504317afc --- /dev/null +++ b/python/sglang/srt/layers/moe/router.py @@ -0,0 +1,342 @@ +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.moe.topk import fused_topk + + +@triton.jit +def fused_moe_router_kernel( + input_ptr, # input (bs, hidden_dim) + moe_router_weight_ptr, # input (num_experts, hidden_dim) + topk_weights_ptr, # output (bs, topk) + topk_ids_ptr, # output (bs, topk) + num_experts: tl.constexpr, + topk: tl.constexpr, + moe_softcapping: tl.constexpr, + moe_renormalize: tl.constexpr, # not supported + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + offsets = tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_dim + + # moe_router_weight is k major + expert_offsets = tl.arange(0, num_experts)[:, None] + router_mask = mask[None, :] + w_router = tl.load( + moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :], + mask=router_mask, + other=0.0, + ) + + x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0) + + # todo: tl.dot? 
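# --- Illustrative eager reference, not part of the patch above -------------
# gelu_and_mul_kernel applies exact (erf-based) GELU to the first half of the
# last dimension and multiplies it by the second half. A minimal PyTorch
# equivalent for the unquantized path (helper name is hypothetical):
import torch

def gelu_and_mul_ref(hidden_states: torch.Tensor) -> torch.Tensor:
    x1, x3 = hidden_states.float().chunk(2, dim=-1)
    gelu_x1 = torch.nn.functional.gelu(x1, approximate="none")
    # the kernel casts the GELU result back to the input dtype before the mul
    out = x3 * gelu_x1.to(hidden_states.dtype).float()
    return out.to(hidden_states.dtype)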
+ logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1) + + # logit softcap + logits_scaled = logits / moe_softcapping + exped = tl.exp(2 * logits_scaled) + top = exped - 1 + bottom = exped + 1 + logits_softcapped = top / bottom * moe_softcapping + + # topk + # assert 1 <= topk <= num_experts + + # 5.38 us + + top1 = tl.argmax(logits_softcapped, axis=0) + tl.store(topk_ids_ptr + pid * topk + 0, top1) # 5.63 us + + top1_v = tl.max(logits_softcapped, axis=0) + invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0) + + tl.store( + topk_weights_ptr + pid * topk + 0, + invsumexp, + ) # 5.73 us + + if topk >= 2: + top2 = tl.argmax( + tl.where( + tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf") + ), + axis=0, + ) + tl.store(topk_ids_ptr + pid * topk + 1, top2) + top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0) + tl.store( + topk_weights_ptr + pid * topk + 1, + tl.exp(top2_v - top1_v) * invsumexp, + ) # 5.95us + + # probably slow + if topk > 2: + topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype) + topk_mask = tl.where( + tl.arange(0, num_experts) != top1, topk_mask, float("-inf") + ) + topk_mask = tl.where( + tl.arange(0, num_experts) != top2, topk_mask, float("-inf") + ) + for i in range(2, topk): + topi = tl.argmax(logits_softcapped + topk_mask, axis=0) + topk_mask = tl.where( + tl.arange(0, num_experts) != topi, topk_mask, float("-inf") + ) + tl.store(topk_ids_ptr + pid * topk + i, topi) + topi_v = tl.sum( + logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0 + ) + tl.store( + topk_weights_ptr + pid * topk + i, + tl.exp(topi_v - top1_v) * invsumexp, + ) + # assert not moe_renormalize, "moe weight renormalization not implemented" + + +def fused_moe_router_impl( + x: torch.Tensor, + router_weight: torch.Tensor, + topk: int, + moe_softcapping: float, +): + assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1] + bs, hidden_dim = x.shape + num_experts = router_weight.shape[0] + + # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device) + topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device) + topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device) + + grid = lambda meta: (bs,) + config = { + "BLOCK_SIZE": triton.next_power_of_2(hidden_dim), + "num_warps": max( + min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4 + ), + } + + fused_moe_router_kernel[grid]( + x, + router_weight, + topk_weights, + topk_ids, + num_experts=num_experts, + topk=topk, + moe_softcapping=moe_softcapping, + moe_renormalize=False, + hidden_dim=hidden_dim, + **config, + ) + + return topk_weights, topk_ids + + +@triton.jit +def fused_moe_router_large_bs_kernel( + a_ptr, # input (bs, hidden_dim) + b_ptr, # input (num_experts, hidden_dim) + topk_weights_ptr, # output (bs, topk) + topk_ids_ptr, # output (bs, topk) + bs, + num_experts: tl.constexpr, + topk: tl.constexpr, # only support topk == 1 + moe_softcapping: tl.constexpr, + moe_renormalize: tl.constexpr, # not supported + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + stride_am: tl.constexpr, + stride_bn: tl.constexpr, +): + + # 1. get block id + pid = tl.program_id(axis=0) + + # 2. create pointers for the first block of A and B + # 2.1. 
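# --- Illustrative eager reference, not part of the patch above -------------
# fused_moe_router_impl fuses the router matmul, logit softcapping, and top-k
# selection. Per token it returns the top-k expert ids and the full softmax
# probabilities of those experts (no renormalization over the selected k),
# which this eager reference reproduces (helper name is hypothetical):
import torch

def moe_router_ref(x, router_weight, topk, moe_softcapping):
    logits = x.float() @ router_weight.float().T              # (bs, num_experts)
    capped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    probs = torch.softmax(capped, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)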
setup a_ptrs with offsets in m and k + offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None] + bs_mask = offs_m < bs + offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :] + a_ptrs = a_ptr + (offs_m * stride_am + offs_k) + + # 2.2. setup b_ptrs with offsets in k and n. + # Note: b matrix is k-major. + offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :] + offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None] + expert_mask = offs_n < num_experts + b_ptrs = b_ptr + (offs_n * stride_bn + offs_k) + + # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N] + # 3.1. iterate in K dimension + # 3.2. transpose tile B + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K // BLOCK_SIZE_K): # hidden_dim % BLOCK_SIZE_K == 0 + a = tl.load( + a_ptrs, + mask=bs_mask, + other=0.0, + ).to(tl.float32) + b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T + acc += tl.dot(a, b) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + + # 4. logit softcap + logits_scaled = acc / moe_softcapping + exped = tl.exp(2 * logits_scaled) + logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping + + # 5. top1 + cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts + top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1) + top1_v = tl.max( + tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True + ) + invsumexp = 1.0 / tl.sum( + tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1 + ) + + # 6. store to output + offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + topk_mask = offs_topk < bs + tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask) + tl.store( + topk_weights_ptr + offs_topk, + invsumexp, + mask=topk_mask, + ) + + +def fused_moe_router_large_bs_impl( + x: torch.Tensor, + router_weight: torch.Tensor, + topk: int, + moe_softcapping: float, + BLOCK_SIZE_M: int, + BLOCK_SIZE_N: int, + BLOCK_SIZE_K: int, +): + assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1] + bs, hidden_dim = x.shape + num_experts = router_weight.shape[0] + + assert num_experts <= BLOCK_SIZE_N + assert hidden_dim % BLOCK_SIZE_K == 0 + assert topk == 1 + + topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device) + topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device) + + grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),) + + fused_moe_router_large_bs_kernel[grid]( + a_ptr=x, + b_ptr=router_weight, + topk_weights_ptr=topk_weights, + topk_ids_ptr=topk_ids, + bs=bs, + num_experts=num_experts, + topk=topk, + moe_softcapping=moe_softcapping, + moe_renormalize=False, + K=hidden_dim, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + stride_am=hidden_dim, + stride_bn=hidden_dim, + ) + + return topk_weights, topk_ids + + +def fused_moe_router_shim( + moe_softcapping, + hidden_states, + gating_output, + topk, + renormalize, +): + assert not renormalize + assert ( + len(hidden_states.shape) == 2 + and hidden_states.shape[1] == gating_output.shape[1] + ) + bs, hidden_dim = hidden_states.shape + num_experts = gating_output.shape[0] + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 16 + BLOCK_SIZE_K = 256 + if ( + bs >= 512 + and topk == 1 + and num_experts <= BLOCK_SIZE_N + and hidden_dim % BLOCK_SIZE_K == 0 + ): + return fused_moe_router_large_bs_impl( + x=hidden_states, + router_weight=gating_output, + topk=topk, + moe_softcapping=moe_softcapping, + BLOCK_SIZE_M=BLOCK_SIZE_M, + 
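# --- Illustrative usage sketch, not part of the patch above ----------------
# fused_moe_router_shim picks the tiled tl.dot kernel only when the batch is
# large and the router shape fits one expert tile (bs >= 512, topk == 1,
# num_experts <= 16, hidden_dim % 256 == 0); otherwise it falls back to the
# per-token kernel. Example call with hypothetical shapes:
import torch

from sglang.srt.layers.moe.router import fused_moe_router_shim

if torch.cuda.is_available():
    bs, hidden_dim, num_experts = 1024, 4096, 8
    hidden_states = torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device="cuda")
    router_weight = torch.randn(num_experts, hidden_dim, dtype=torch.bfloat16, device="cuda")
    # satisfies the large-batch conditions -> fused_moe_router_large_bs_impl path
    topk_weights, topk_ids = fused_moe_router_shim(
        moe_softcapping=30.0,
        hidden_states=hidden_states,
        gating_output=router_weight,
        topk=1,
        renormalize=False,
    )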
BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + else: + return fused_moe_router_impl( + x=hidden_states, + router_weight=gating_output, + topk=topk, + moe_softcapping=moe_softcapping, + ) + + +class FusedMoeRouter: + def __init__(self, router_linear, topk, moe_softcapping) -> None: + self.router_linear = router_linear + self.topk = topk + self.moe_softcapping = moe_softcapping + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, x: torch.Tensor, residual: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if x.is_cuda: + return self.forward_cuda(x, residual) + else: + return self.forward_vllm(x, residual) + + def forward_cuda( + self, x: torch.Tensor, autotune=False + ) -> Tuple[torch.Tensor, torch.Tensor]: + return fused_moe_router_shim( + moe_softcapping=self.moe_softcapping, + hidden_states=x, + gating_output=self.router_linear.weight, + topk=self.topk, + renormalize=False, + ) + + def forward_vllm( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # g, _ = self.router_linear.forward(x) + g = x.float() @ self.router_linear.weight.T.float() + + g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping + + return fused_topk(x, g, self.topk, False) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index ff56bcef8..2ef25daef 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -15,28 +15,36 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" +import functools +import json +import logging +import math +import os +import warnings +from typing import Iterable, Optional, Tuple -from typing import Iterable, List, Optional, Tuple - +import numpy as np import torch -import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.router import fused_moe_router_shim from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -44,47 +52,17 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.loader import DefaultModelLoader from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix +from sglang.srt.utils import dump_to_file + +logger = logging.getLogger(__name__) -class Grok1MLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - quant_config: 
Optional[QuantizationConfig] = None, - prefix: str = "", - reduce_results=True, - use_presharded_weights: bool = False, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, - [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=add_prefix("gate_up_proj", prefix), - use_presharded_weights=use_presharded_weights, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - prefix=add_prefix("down_proj", prefix), - reduce_results=reduce_results, - use_presharded_weights=use_presharded_weights, - ) - self.act_fn = GeluAndMul(approximate="tanh") - - def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x +debug_tensor_dump_output_folder = None +debug_tensor_dump_inject = False class Grok1MoE(nn.Module): @@ -108,51 +86,55 @@ class Grok1MoE(nn.Module): tp_size: Optional[int] = None, reduce_results=True, use_presharded_weights: bool = False, - prefix: str = "", + inplace: bool = True, + no_combine: bool = False, ): super().__init__() self.hidden_size = hidden_size - # Gate always runs at half / full precision for now. + # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961) self.gate = ReplicatedLinear( hidden_size, num_experts, bias=False, - params_dtype=params_dtype, + params_dtype=torch.float32, quant_config=None, - prefix=add_prefix("gate", prefix), ) self.router_logit_softcapping = getattr( config, "router_logit_softcapping", 30.0 ) - self.experts = FusedMoE( + custom_routing_function = functools.partial( + fused_moe_router_shim, self.router_logit_softcapping + ) + + kwargs = {} + if global_server_args_dict["enable_ep_moe"]: + MoEImpl = EPMoE + else: + MoEImpl = FusedMoE + kwargs["reduce_results"] = reduce_results + kwargs["use_presharded_weights"] = use_presharded_weights + kwargs["inplace"] = inplace + kwargs["no_combine"] = no_combine + + self.experts = MoEImpl( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - reduce_results=reduce_results, renormalize=False, quant_config=quant_config, tp_size=tp_size, + custom_routing_function=custom_routing_function, activation="gelu", - use_presharded_weights=use_presharded_weights, - prefix=add_prefix("experts", prefix), + **kwargs, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # NOTE: hidden_states can have either 1D or 2D shape. 
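# --- Illustrative sketch, not part of the patch above ----------------------
# In the new Grok1MoE, routing moves out of forward(): the gate weight is
# handed to the experts module together with a custom_routing_function, which
# is fused_moe_router_shim with moe_softcapping already bound. Conceptually:
import functools

from sglang.srt.layers.moe.router import fused_moe_router_shim

routing_fn = functools.partial(fused_moe_router_shim, 30.0)  # 30.0 = softcap
# Inside FusedMoE/EPMoE this is then invoked (roughly) as
#   topk_weights, topk_ids = routing_fn(
#       hidden_states=hidden_states, gating_output=gate_weight,
#       topk=top_k, renormalize=False)
# so the router matmul, softcap, and top-k all run in the fused Triton kernel.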
- orig_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - router_logits = 30.0 * F.tanh(router_logits / 30.0) - # need to assert self.gate.quant_method is unquantized - final_hidden_states = self.experts(hidden_states, router_logits) - return final_hidden_states.view(orig_shape) + return self.experts(hidden_states, self.gate.weight) class Grok1Attention(nn.Module): @@ -167,31 +149,33 @@ class Grok1Attention(nn.Module): rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, - prefix: str = "", + load_presharded_attn: bool = False, ) -> None: super().__init__() self.config = config self.layer_id = layer_id self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + attn_tp_rank = get_tensor_model_parallel_rank() + attn_tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: + if self.total_num_kv_heads >= attn_tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % attn_tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) self.head_dim = getattr(config, "head_dim", 128) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta + self.load_presharded_attn = load_presharded_attn self.qkv_proj = QKVParallelLinear( hidden_size, @@ -200,7 +184,9 @@ class Grok1Attention(nn.Module): self.total_num_kv_heads, bias=False, quant_config=quant_config, - prefix=add_prefix("qkv_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + load_presharded_attn=self.load_presharded_attn, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -208,7 +194,9 @@ class Grok1Attention(nn.Module): bias=False, quant_config=quant_config, reduce_results=reduce_results, - prefix=add_prefix("o_proj", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + use_presharded_weights=self.load_presharded_attn, ) self.rotary_emb = get_rope( self.head_dim, @@ -227,7 +215,6 @@ class Grok1Attention(nn.Module): num_kv_heads=self.num_kv_heads, layer_id=layer_id, logit_cap=logit_cap, - prefix=add_prefix("attn", prefix), ) def forward( @@ -236,10 +223,73 @@ class Grok1Attention(nn.Module): hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: + if hidden_states.shape[0] == 0: + assert ( + not self.o_proj.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return hidden_states + if debug_tensor_dump_output_folder: + dump_to_file( + debug_tensor_dump_output_folder, + f"attn_input_{self.layer_id}", + hidden_states, + ) + + if debug_tensor_dump_inject: + name = os.path.join( + debug_tensor_dump_output_folder, + 
f"jax_dump_attn_input_{self.layer_id}.npy", + ) + logger.info(f"Load {name} from jax.") + x = np.load(name) + hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to( + hidden_states + ) + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) + + if debug_tensor_dump_output_folder: + num_tokens = q.shape[0] + num_heads_q = self.num_heads + head_dim = self.head_dim + num_heads_kv = k.numel() // (num_tokens * head_dim) + + dump_to_file( + debug_tensor_dump_output_folder, + f"q_{self.layer_id}", + tensor_model_parallel_all_gather( + q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1 + ).contiguous(), + ) + dump_to_file( + debug_tensor_dump_output_folder, + f"k_{self.layer_id}", + tensor_model_parallel_all_gather( + k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 + ).contiguous(), + ) + dump_to_file( + debug_tensor_dump_output_folder, + f"v_{self.layer_id}", + tensor_model_parallel_all_gather( + v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1 + ).contiguous(), + ) + attn_output = self.attn(q, k, v, forward_batch) + + if debug_tensor_dump_output_folder: + dump_to_file( + debug_tensor_dump_output_folder, + f"attn_output_{self.layer_id}", + tensor_model_parallel_all_gather( + attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(), + dim=1, + ).contiguous(), + ) + output, _ = self.o_proj(attn_output) return output @@ -250,8 +300,9 @@ class Grok1DecoderLayer(nn.Module): config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, - use_presharded_weights: bool = False, - prefix: str = "", + load_presharded_moe: bool = False, + load_presharded_attn: bool = False, + load_presharded_mlp: bool = False, ) -> None: super().__init__() self.num_experts = config.num_local_experts @@ -268,7 +319,8 @@ class Grok1DecoderLayer(nn.Module): layer_id=layer_id, rope_theta=rope_theta, quant_config=quant_config, - prefix=add_prefix("attn", prefix), + reduce_results=False, + load_presharded_attn=load_presharded_attn, ) self.block_sparse_moe = Grok1MoE( config=config, @@ -282,38 +334,68 @@ class Grok1DecoderLayer(nn.Module): ), quant_config=quant_config, reduce_results=True, - use_presharded_weights=use_presharded_weights, - prefix=add_prefix("block_sparse_moe", prefix), + use_presharded_weights=load_presharded_moe, + inplace=True, + no_combine=False, # just a suggestion to not combine topk ) + self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.ffn = self.block_sparse_moe + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, - ) -> torch.Tensor: + residual: Optional[torch.Tensor] = None, + deferred_norm: Optional[RMSNorm] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]: # Self Attention - hidden_states = ( - self.post_attn_norm( - self.self_attn( - positions=positions, - hidden_states=self.pre_attn_norm(hidden_states), - forward_batch=forward_batch, - ) + if deferred_norm is not None: + assert residual is not None + # here hidden_states is output of ffn, residual is residual from after previous attn layer + hidden_states, residual = fused_dual_residual_rmsnorm( + hidden_states, + residual, + deferred_norm.weight, + 
self.pre_attn_norm.weight, + deferred_norm.variance_epsilon, ) - + hidden_states + else: + # here hidden_states is the residual + hidden_states, residual = ( + fused_rmsnorm( + hidden_states, + self.pre_attn_norm.weight, + self.pre_attn_norm.variance_epsilon, + ), + hidden_states, + ) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + if get_tensor_model_parallel_world_size() > 1: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + hidden_states, residual = fused_dual_residual_rmsnorm( + hidden_states, + residual, + self.post_attn_norm.weight, + self.pre_moe_norm.weight, + self.post_attn_norm.variance_epsilon, ) # Fully Connected - hidden_states = ( - self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) - + hidden_states - ) - return hidden_states + hidden_states = self.ffn(hidden_states) + return hidden_states, residual, self.post_moe_norm # defer layernorm class Grok1Model(nn.Module): @@ -321,8 +403,10 @@ class Grok1Model(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - use_presharded_weights: bool = False, - prefix: str = "", + load_presharded_moe: bool = False, + load_presharded_embedding: bool = False, + load_presharded_attn: bool = False, + load_presharded_mlp: bool = False, ) -> None: super().__init__() self.config = config @@ -332,7 +416,7 @@ class Grok1Model(nn.Module): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - prefix=add_prefix("embed_tokens", prefix), + use_presharded_weights=load_presharded_embedding, ) self.layers = nn.ModuleList( [ @@ -340,8 +424,9 @@ class Grok1Model(nn.Module): config, i, quant_config=quant_config, - use_presharded_weights=use_presharded_weights, - prefix=add_prefix(f"layers.{i}", prefix), + load_presharded_moe=load_presharded_moe, + load_presharded_attn=load_presharded_attn, + load_presharded_mlp=load_presharded_mlp, ) for i in range(config.num_hidden_layers) ] @@ -361,10 +446,48 @@ class Grok1Model(nn.Module): else: hidden_states = input_embeds + residual, deferred_norm = None, None for i in range(len(self.layers)): - hidden_states = self.layers[i](positions, hidden_states, forward_batch) - hidden_states = self.norm(hidden_states) - hidden_states.mul_(self.config.output_multiplier_scale) + hidden_states, residual, deferred_norm = self.layers[i]( + positions, hidden_states, forward_batch, residual, deferred_norm + ) + + if debug_tensor_dump_output_folder: + hidden_states = ( + fused_rmsnorm( + hidden_states, + deferred_norm.weight, + deferred_norm.variance_epsilon, + ) + + residual + ) + + dump_to_file( + debug_tensor_dump_output_folder, + "last_hidden_before_norm", + hidden_states, + ) + + hidden_states = fused_rmsnorm( + hidden_states, + self.norm.weight, + self.norm.variance_epsilon, + ) + + dump_to_file( + debug_tensor_dump_output_folder, + "last_hidden_after_norm", + hidden_states, + ) + else: + hidden_states, _ = fused_dual_residual_rmsnorm( + hidden_states, + residual, + deferred_norm.weight, + self.norm.weight, + deferred_norm.variance_epsilon, + ) + return hidden_states @@ -373,31 +496,77 @@ class Grok1ForCausalLM(nn.Module): self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - if ( + # Get presharded weights. 
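# --- Illustrative note, not part of the patch above ------------------------
# The deferred-norm rewrite keeps the original math. Each decoder layer now
# returns (ffn_out, residual, post_moe_norm) instead of applying its trailing
# norm + residual add itself; the next layer (or the final model norm) folds
# that deferred norm into a single fused_dual_residual_rmsnorm call:
#
#   new_residual = residual + post_moe_norm(ffn_out)     # the kernel's "mid"
#   next_input   = next_norm(new_residual)               # the kernel's "output"
#
# Eager check with a toy RMSNorm (all names and shapes hypothetical):
import torch

def rmsnorm(x, w, eps=1e-5):
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w

ffn_out, residual = torch.randn(4, 64), torch.randn(4, 64)
w_deferred, w_next = torch.randn(64), torch.randn(64)

# original per-layer pattern: post-norm plus residual, then the next pre-norm
hidden = rmsnorm(ffn_out, w_deferred) + residual
next_input_ref = rmsnorm(hidden, w_next)

# fused pattern used above (one kernel instead of two norms and one add)
new_residual = residual + rmsnorm(ffn_out, w_deferred)
next_input = rmsnorm(new_residual, w_next)

assert torch.allclose(next_input, next_input_ref)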
+ self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False) + self.load_presharded_moe = ( self.config.num_local_experts > 0 and get_tensor_model_parallel_world_size() > 1 - ): - self.use_presharded_weights = True + ) + self.load_presharded_attn = getattr(config, "load_presharded_attn", False) + self.load_presharded_embedding = getattr( + config, "load_presharded_embedding", False + ) + + self.is_weights_presharded = ( + self.load_presharded_mlp + or self.load_presharded_moe + or self.load_presharded_attn + or self.load_presharded_embedding + ) + + if self.is_weights_presharded: setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) - else: - self.use_presharded_weights = False + + default_replicate_lm_head = False + self.replicate_lm_head = getattr( + config, "replicate_lm_head", default_replicate_lm_head + ) self.model = Grok1Model( config, quant_config=quant_config, - use_presharded_weights=self.use_presharded_weights, - prefix=add_prefix("model", prefix), + load_presharded_moe=self.load_presharded_moe, + load_presharded_embedding=self.load_presharded_embedding, + load_presharded_attn=self.load_presharded_attn, + load_presharded_mlp=self.load_presharded_mlp, ) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix) - ) - self.logits_processor = LogitsProcessor(config) + + lm_head_params_dtype = None + if self.replicate_lm_head: + self.lm_head = ReplicatedLinear( + config.hidden_size, + config.vocab_size, + bias=False, + params_dtype=lm_head_params_dtype, + ) + self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + use_presharded_weights=self.load_presharded_embedding, + params_dtype=lm_head_params_dtype, + ) + self.logits_processor = LogitsProcessor(config) + + # Dump tensors for debugging + global debug_tensor_dump_output_folder, debug_tensor_dump_inject + debug_tensor_dump_output_folder = global_server_args_dict[ + "debug_tensor_dump_output_folder" + ] + debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"] + warnings.filterwarnings("ignore", category=FutureWarning) + + if get_tensor_model_parallel_rank() == 0: + logger.info( + f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, " + f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B" + ) def forward( self, @@ -406,6 +575,9 @@ class Grok1ForCausalLM(nn.Module): forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + if debug_tensor_dump_output_folder: + dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch @@ -414,21 +586,28 @@ class Grok1ForCausalLM(nn.Module): def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], - ): - num_experts = self.config.num_local_experts - - stacked_params_mapping = [ + num_experts: Optional[int] = None, + ignore_parent_name: bool = False, + ) -> dict[str, torch.Tensor]: + if num_experts is None: + num_experts = self.config.num_local_experts + stacked_params_mapping = [] + stacked_params_mapping += [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ] + stacked_params_mapping += [ + # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), 
("gate_up_proj", "up_proj", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", @@ -439,14 +618,25 @@ class Grok1ForCausalLM(nn.Module): all_names = set(params_dict.keys()) hit_names = set() - def load_weight_wrapper(name, loaded_weight, *args, **kwargs): + def load_weight_wrapper( + name: str, loaded_weight: torch.Tensor, *args, **kwargs + ): + if ignore_parent_name: + name = name.split(".")[-1] + if name not in params_dict: return + # Fuse constant multipliers into the weights + if "lm_head" in name: + loaded_weight = ( + loaded_weight.to(torch.float32) + * self.config.output_multiplier_scale + ) + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, *args, **kwargs) - hit_names.add(name) for name, loaded_weight in weights: @@ -460,7 +650,6 @@ class Grok1ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - load_weight_wrapper(name, loaded_weight, shard_id) break else: @@ -487,13 +676,79 @@ class Grok1ForCausalLM(nn.Module): load_weight_wrapper(name=name, loaded_weight=loaded_weight) + if len(hit_names) > 5: + missing = all_names - hit_names + missing_exclude_scales = {x for x in missing if "scale" not in x} + logger.info( + f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}", + ) + if len(missing_exclude_scales) > 0: + raise ValueError( + f"load_weights failed because some weights are missing: {missing_exclude_scales=}." 
+ ) + + elif len(hit_names) == 0: + raise ValueError("load_weights failed because it did not hit any names.") + + return hit_names + + def get_num_params_analytical(self): + cfg = self.config + moe_intermediate_size = getattr( + cfg, + "moe_intermediate_size", + getattr(cfg, "intermediate_size", None), + ) + num_experts = cfg.num_local_experts + + wq = ( + cfg.num_hidden_layers + * cfg.hidden_size + * cfg.num_attention_heads + * cfg.head_dim + ) + wkv = ( + cfg.num_hidden_layers + * cfg.hidden_size + * cfg.num_key_value_heads + * cfg.head_dim + * 2 + ) + out = ( + cfg.num_hidden_layers + * cfg.hidden_size + * cfg.num_attention_heads + * cfg.head_dim + ) + ffn1 = ( + cfg.num_hidden_layers + * num_experts + * cfg.hidden_size + * moe_intermediate_size + * 2 + ) + ffn2 = ( + cfg.num_hidden_layers + * num_experts + * cfg.hidden_size + * moe_intermediate_size + ) + embed = cfg.hidden_size * cfg.vocab_size * 2 + return wq + wkv + out + ffn1 + ffn2 + embed + + def get_num_params_torch(self): + return ( + sum(p.numel() for p in self.parameters()) + * get_tensor_model_parallel_world_size() + ) + old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") def _prepare_presharded_weights( self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool -) -> Tuple[str, List[str], bool]: +) -> Tuple[str, list[str], bool]: import glob import os @@ -522,7 +777,7 @@ def _prepare_presharded_weights( # The new format allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"] - hf_weights_files: List[str] = [] + hf_weights_files = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
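# --- Illustrative example, not part of the patch above ---------------------
# get_num_params_analytical sums the attention projections, the expert MLPs,
# and the embedding/lm_head matrices. With a small hypothetical config the
# same arithmetic is easy to check by hand:
num_layers, hidden, heads, kv_heads, head_dim = 2, 1024, 8, 2, 128
num_experts, moe_inter, vocab = 4, 4096, 32000

wq    = num_layers * hidden * heads * head_dim
wkv   = num_layers * hidden * kv_heads * head_dim * 2
out   = num_layers * hidden * heads * head_dim
ffn1  = num_layers * num_experts * hidden * moe_inter * 2   # w1 and w3 per expert
ffn2  = num_layers * num_experts * hidden * moe_inter       # w2 per expert
embed = hidden * vocab * 2                                   # embed_tokens + lm_head

print(f"{(wq + wkv + out + ffn1 + ffn2 + embed) / 1e9:.3f} B parameters")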