[AMD] Add test_fused_moe.py and test_rope_rocm.py to AMD CI (#5246)
This commit is contained in:
@@ -125,6 +125,7 @@ suites = {
|
||||
TestFile("test_chunked_prefill.py", 313),
|
||||
TestFile("test_eval_fp8_accuracy.py", 303),
|
||||
TestFile("test_function_call_parser.py", 10),
|
||||
TestFile("test_fused_moe.py", 30),
|
||||
TestFile("test_input_embeddings.py", 38),
|
||||
TestFile("test_metrics.py", 32),
|
||||
TestFile("test_no_chunked_prefill.py", 108),
|
||||
@@ -142,6 +143,7 @@ suites = {
|
||||
TestFile("test_vertex_endpoint.py", 31),
|
||||
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
|
||||
TestFile("test_reasoning_parser.py", 5),
|
||||
TestFile("test_rope_rocm.py", 3),
|
||||
],
|
||||
"per-commit-npu": [
|
||||
TestFile("test_ascend_attention_backend.py", 400),
|
||||
|
||||
@@ -6,8 +6,14 @@ from tqdm import tqdm
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.utils import is_hip
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
|
||||
class TestFusedMOE(CustomTestCase):
|
||||
NUM_EXPERTS = [8, 64]
|
||||
@@ -64,7 +70,7 @@ class TestFusedMOE(CustomTestCase):
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
if w1.dtype == torch.float8_e4m3fn:
|
||||
if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]:
|
||||
w1_compute = w1.to(a.dtype)
|
||||
w2_compute = w2.to(a.dtype)
|
||||
|
||||
@@ -97,7 +103,7 @@ class TestFusedMOE(CustomTestCase):
|
||||
if use_fp8_w8a8:
|
||||
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if not (capability[0] >= 9 or capability == (8, 9)):
|
||||
if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)):
|
||||
return
|
||||
|
||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||
@@ -106,12 +112,26 @@ class TestFusedMOE(CustomTestCase):
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||
|
||||
w1_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||
w2_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||
a1_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||
a2_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||
|
||||
# Handle HIP case: normalize float8 weights so fused kernel doesn't break
|
||||
# on ROCm.
|
||||
if _is_fp8_fnuz:
|
||||
# Normalize to e4m3fnuz on HIP
|
||||
w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
input_scale=a1_scale,
|
||||
)
|
||||
w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w2,
|
||||
weight_scale=w2_scale,
|
||||
input_scale=a2_scale,
|
||||
)
|
||||
|
||||
sglang_output = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
@@ -127,12 +147,19 @@ class TestFusedMOE(CustomTestCase):
|
||||
)
|
||||
|
||||
torch_output = self.torch_naive_moe(
|
||||
a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
sglang_output, torch_output, rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
else:
|
||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||
|
||||
116
test/srt/test_rope_rocm.py
Normal file
116
test/srt/test_rope_rocm.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.rotary_embedding import RotaryEmbedding
|
||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
# Deterministic seed so the random query/key tensors below are reproducible.
torch.manual_seed(0)

# True when running on a ROCm/HIP build of PyTorch.
_is_hip = is_hip()
# AITer rotary kernels are only exercised when explicitly opted in via the
# SGLANG_USE_AITER env var AND we are on HIP; always False on CUDA.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


# Parameter matrix shared by both test classes below. Tuple layout:
# (head_size, rotary_dim, max_pos, base, is_neox, dtype, device,
#  batch_size, seq_len, num_q_heads, num_kv_heads)
_CASES = [
    (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
    (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
    (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
    (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
]
|
||||
|
||||
|
||||
@unittest.skipIf(_use_aiter, reason="SGLANG_USE_AITER=1 will not use vllm path.")
class TestRotaryEmbeddingNative(CustomTestCase):
    """Check that RotaryEmbedding.forward_hip() agrees with forward_native()."""

    def _run_case(
        self,
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        """Run both rotary paths on identical random inputs and compare."""

        def _make_rope() -> RotaryEmbedding:
            # Independent module per path so neither can affect the other.
            return RotaryEmbedding(
                head_size, rotary_dim, max_pos, base, is_neox, dtype
            ).to(device)

        rope_ref = _make_rope()
        rope_hip = _make_rope()

        # Flattened (batch * seq) position ids and packed multi-head tensors.
        pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
        n_tokens = batch_size * seq_len
        query = torch.randn(n_tokens, num_q * head_size, dtype=dtype, device=device)
        key = torch.randn(n_tokens, num_kv * head_size, dtype=dtype, device=device)

        # Clone inputs: the rotary kernels may rotate in place.
        q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone())
        q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        """Drive over the full parameter matrix using subTest()."""
        for case in _CASES:
            with self.subTest(case=case):
                self._run_case(*case)
|
||||
|
||||
|
||||
@unittest.skipIf(not _use_aiter, reason="Requires AMD GPU plus SGLANG_USE_AITER=1")
class TestRotaryEmbeddingAITer(CustomTestCase):
    """Same forward_hip()-vs-forward_native() comparison, via AITer's rope."""

    @staticmethod
    def _run_case_aiter(
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        """Run one parameter tuple through the AITer rotary embedding."""
        # Imported lazily: aiter is only installed on ROCm CI machines.
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        def _make_rope() -> "AiterRotaryEmbedding":
            # Independent module per path so neither can affect the other.
            return AiterRotaryEmbedding(
                head_size, rotary_dim, max_pos, base, is_neox, dtype
            ).to(device)

        rope_ref = _make_rope()
        rope_hip = _make_rope()

        # Flattened (batch * seq) position ids and packed multi-head tensors.
        pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
        n_tokens = batch_size * seq_len
        query = torch.randn(n_tokens, num_q * head_size, dtype=dtype, device=device)
        key = torch.randn(n_tokens, num_kv * head_size, dtype=dtype, device=device)

        # Clone inputs: the rotary kernels may rotate in place.
        q_ref, k_ref = rope_ref.forward_native(pos_ids, query.clone(), key.clone())
        q_hip, k_hip = rope_hip.forward_hip(pos_ids, query.clone(), key.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        """Drive over the full parameter matrix using subTest()."""
        for case in _CASES:
            with self.subTest(case=case):
                self._run_case_aiter(*case)
|
||||
|
||||
|
||||
# Allow running this test file directly: `python test_rope_rocm.py`.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user