diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py
index 931dd2b9f..87d46f6cf 100644
--- a/python/sglang/srt/layers/elementwise.py
+++ b/python/sglang/srt/layers/elementwise.py
@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
 
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
diff --git a/python/sglang/srt/layers/moe/router.py b/python/sglang/srt/layers/moe/router.py
index 504317afc..ffa120cad 100644
--- a/python/sglang/srt/layers/moe/router.py
+++ b/python/sglang/srt/layers/moe/router.py
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 7a7eb8884..574dffd63 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -171,6 +171,7 @@ def input_to_float8(
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)