Fix run time error in ROCm platform (#5147)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: root <root@dell300x-pla-t10-17.pla.dcgpu>
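Summary: the Triton launch heuristics in the rmsnorm and MoE-router paths capped num_warps at 32, which is valid on CUDA but fails at run time on ROCm; a warp there is a 64-lane wavefront, so 32 warps request 2048 threads against a 1024-thread work-group limit (that, at least, is the apparent rationale). The cap becomes 16 when is_hip() is true. The fp8 quantization path is also fixed to actually switch to torch.float8_e4m3fnuz on ROCm instead of only lowering fp8_max to 224.0.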
@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
 
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
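For reference, a minimal standalone sketch of the warp-count heuristic as patched above; the helper name is hypothetical, and the 64-lane/1024-thread reasoning in the comments is my reading of the change, not text from the PR:

    import triton

    def _num_warps(hidden_dim: int, is_hip: bool) -> int:
        # Triton multiplies num_warps by the warp width to get the block size:
        # 32 lanes per warp on NVIDIA, but a 64-lane wavefront on AMD. Both
        # vendors limit a block/work-group to 1024 threads, so the largest safe
        # warp count is 32 on CUDA and 16 on HIP; requesting 32 warps on ROCm
        # (32 * 64 = 2048 threads) is what triggered the run time error.
        cap = 16 if is_hip else 32
        # One warp per 256 elements, rounded up to a power of two, clamped to [4, cap].
        return max(min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), cap), 4)

    assert _num_warps(4096, is_hip=False) == 16
    assert _num_warps(16384, is_hip=True) == 16   # capped on HIP
    assert _num_warps(16384, is_hip=False) == 32  # still 32 on CUDA

Despite its name, min_num_warps acts as the upper bound here; the lower bound stays at 4.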
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
 
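For context on how these config dicts are consumed: they are splatted into the Triton launch, where BLOCK_SIZE arrives as a tl.constexpr and num_warps is taken by the launcher itself. A self-contained toy example of that mechanism (not code from this PR):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _scale_kernel(x_ptr, out_ptr, n, s, BLOCK_SIZE: tl.constexpr):
        offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n
        tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * s, mask=mask)

    x = torch.randn(10000, device="cuda")  # "cuda" also targets ROCm/HIP builds
    out = torch.empty_like(x)
    # BLOCK_SIZE reaches the kernel as a constexpr; num_warps is a launch option
    # and, on HIP, must not exceed 16 (16 warps * 64 lanes = 1024 threads).
    config = {"BLOCK_SIZE": 1024, "num_warps": 16}
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    _scale_kernel[grid](x, out, x.numel(), 2.0, **config)
    assert torch.allclose(out, 2 * x)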
@@ -171,6 +171,7 @@ def input_to_float8(
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
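The one-line addition above matters because the function previously lowered fp8_max to 224.0 on ROCm but still cast to the caller's non-FNUZ dtype; AMD hardware uses the FNUZ fp8 encodings, and torch.finfo(torch.float8_e4m3fnuz).max is 240 rather than e4m3fn's 448. A simplified per-tensor sketch of the corrected path (signature and return values are illustrative, not the exact PR code):

    import torch

    def input_to_float8_sketch(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
        # Per-tensor dynamic quantization: pick a scale so the largest magnitude
        # maps to the fp8 maximum, saturate, then cast.
        finfo = torch.finfo(dtype)
        amax = x.abs().amax().clamp(min=1e-12)
        fp8_max = finfo.max
        if torch.version.hip is not None:  # ROCm build of PyTorch
            dtype = torch.float8_e4m3fnuz  # AMD hardware uses the FNUZ encoding
            fp8_max = 224.0                # conservative cap below finfo.max (240)
        scale = fp8_max / amax
        x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
        return x_scl_sat.to(dtype), scale.reciprocal()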