Fix runtime error on the ROCm platform (#5147)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: root <root@dell300x-pla-t10-17.pla.dcgpu>
This commit is contained in:
@@ -5,6 +5,9 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.moe.topk import fused_topk
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
|
||||
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
|
||||
|
||||
grid = lambda meta: (bs,)
|
||||
|
||||
min_num_warps = 16 if _is_hip else 32
|
||||
|
||||
config = {
|
||||
"BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
|
||||
"num_warps": max(
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
|
||||
min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user