[AMD] Add test_fused_moe.py and test_rope_rocm.py to AMD CI (#5246)
This commit is contained in:
@@ -6,8 +6,14 @@ from tqdm import tqdm
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
|
||||
class TestFusedMOE(CustomTestCase):
|
||||
NUM_EXPERTS = [8, 64]
|
||||
@@ -64,7 +70,7 @@ class TestFusedMOE(CustomTestCase):
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
if w1.dtype == torch.float8_e4m3fn:
|
||||
if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]:
|
||||
w1_compute = w1.to(a.dtype)
|
||||
w2_compute = w2.to(a.dtype)
|
||||
|
||||
@@ -97,7 +103,7 @@ class TestFusedMOE(CustomTestCase):
|
||||
if use_fp8_w8a8:
|
||||
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if not (capability[0] >= 9 or capability == (8, 9)):
|
||||
if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)):
|
||||
return
|
||||
|
||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||
@@ -106,12 +112,26 @@ class TestFusedMOE(CustomTestCase):
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||
|
||||
w1_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||
w2_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||
a1_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||
a2_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||
|
||||
# Handle HIP case: normalize float8 weights so fused kernel doesn't break
|
||||
# on ROCm.
|
||||
if _is_fp8_fnuz:
|
||||
# Convert the weights (and their scales) from e4m3fn to e4m3fnuz when the platform uses the fnuz FP8 variant
|
||||
w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
input_scale=a1_scale,
|
||||
)
|
||||
w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w2,
|
||||
weight_scale=w2_scale,
|
||||
input_scale=a2_scale,
|
||||
)
|
||||
|
||||
sglang_output = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
@@ -127,12 +147,19 @@ class TestFusedMOE(CustomTestCase):
|
||||
)
|
||||
|
||||
torch_output = self.torch_naive_moe(
|
||||
a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
sglang_output, torch_output, rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
else:
|
||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||
|
||||
Reference in New Issue
Block a user