diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 33935fc14..25b253e6b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -125,6 +125,7 @@ suites = { TestFile("test_chunked_prefill.py", 313), TestFile("test_eval_fp8_accuracy.py", 303), TestFile("test_function_call_parser.py", 10), + TestFile("test_fused_moe.py", 30), TestFile("test_input_embeddings.py", 38), TestFile("test_metrics.py", 32), TestFile("test_no_chunked_prefill.py", 108), @@ -142,6 +143,7 @@ suites = { TestFile("test_vertex_endpoint.py", 31), # TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701 TestFile("test_reasoning_parser.py", 5), + TestFile("test_rope_rocm.py", 3), ], "per-commit-npu": [ TestFile("test_ascend_attention_backend.py", 400), diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 9b6af04bc..d1c2735d1 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -6,8 +6,14 @@ from tqdm import tqdm from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.utils import is_hip from sglang.test.test_utils import CustomTestCase +_is_hip = is_hip() +_is_fp8_fnuz = is_fp8_fnuz() + class TestFusedMOE(CustomTestCase): NUM_EXPERTS = [8, 64] @@ -64,7 +70,7 @@ class TestFusedMOE(CustomTestCase): topk_weight = topk_weight.view(-1) topk_ids = topk_ids.view(-1) - if w1.dtype == torch.float8_e4m3fn: + if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]: w1_compute = w1.to(a.dtype) w2_compute = w2.to(a.dtype) @@ -97,7 +103,7 @@ class TestFusedMOE(CustomTestCase): if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 capability = torch.cuda.get_device_capability() - if not (capability[0] >= 9 or capability == (8, 9)): + if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)): return a = self.create_random_cuda_tensor((m, k), dtype) @@ -106,12 +112,26 @@ class TestFusedMOE(CustomTestCase): w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = self.create_random_cuda_tensor(e, torch.float32) w2_scale = self.create_random_cuda_tensor(e, torch.float32) a1_scale = self.create_random_cuda_tensor(1, torch.float32) a2_scale = self.create_random_cuda_tensor(1, torch.float32) + # Handle HIP case: normalize float8 weights so fused kernel doesn't break + # on ROCm. + if _is_fp8_fnuz: + # Normalize to e4m3fnuz on HIP + w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w1, + weight_scale=w1_scale, + input_scale=a1_scale, + ) + w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w2, + weight_scale=w2_scale, + input_scale=a2_scale, + ) + sglang_output = fused_moe( a, w1, @@ -127,12 +147,19 @@ class TestFusedMOE(CustomTestCase): ) torch_output = self.torch_naive_moe( - a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale + a, + w1, + w2, + score, + topk, + w1_scale, + w2_scale, + a1_scale, + a2_scale, ) torch.testing.assert_close( sglang_output, torch_output, rtol=rtol, atol=atol ) - else: a = self.create_random_cuda_tensor((m, k), dtype) w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) diff --git a/test/srt/test_rope_rocm.py b/test/srt/test_rope_rocm.py new file mode 100644 index 000000000..5850e7061 --- /dev/null +++ b/test/srt/test_rope_rocm.py @@ -0,0 +1,116 @@ +import unittest + +import torch + +from sglang.srt.layers.rotary_embedding import RotaryEmbedding +from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.test.test_utils import CustomTestCase + +torch.manual_seed(0) + +_is_hip = is_hip() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + + +_CASES = [ + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), +] + + +@unittest.skipIf(_use_aiter, reason="SGLANG_USE_AITER=1 will not use vllm path.") +class TestRotaryEmbeddingNative(CustomTestCase): + # Compare RotaryEmbedding.forward_hip() to forward_native(). + def _run_case( + self, + head_size: int, + rotary_dim: int, + max_pos: int, + base: int, + is_neox: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q: int, + num_kv: int, + ) -> None: + rope_ref = RotaryEmbedding( + head_size, rotary_dim, max_pos, base, is_neox, dtype + ).to(device) + rope_hip = RotaryEmbedding( + head_size, rotary_dim, max_pos, base, is_neox, dtype + ).to(device) + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv * head_size, dtype=dtype, device=device + ) + + q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone()) + q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone()) + + torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2) + + def test_all_cases(self) -> None: + """Drive over the full parameter matrix using subTest().""" + for case in _CASES: + with self.subTest(case=case): + self._run_case(*case) + + +@unittest.skipIf(not _use_aiter, reason="Requires AMD GPU plus SGLANG_USE_AITER=1") +class TestRotaryEmbeddingAITer(CustomTestCase): + @staticmethod + def _run_case_aiter( + head_size: int, + rotary_dim: int, + max_pos: int, + base: int, + is_neox: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q: int, + num_kv: int, + ) -> None: + from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding + + rope_ref = AiterRotaryEmbedding( + head_size, rotary_dim, max_pos, base, is_neox, dtype + ).to(device) + rope_hip = AiterRotaryEmbedding( + head_size, rotary_dim, max_pos, base, is_neox, dtype + ).to(device) + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv * head_size, dtype=dtype, device=device + ) + + q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone()) + q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone()) + + torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2) + + def test_all_cases(self) -> None: + for case in _CASES: + with self.subTest(case=case): + self._run_case_aiter(*case) + + +if __name__ == "__main__": + unittest.main()