[AMD] Add test_fused_moe.py and test_rope_rocm.py to AMD CI (#5246)
This commit is contained in:
@@ -125,6 +125,7 @@ suites = {
|
|||||||
TestFile("test_chunked_prefill.py", 313),
|
TestFile("test_chunked_prefill.py", 313),
|
||||||
TestFile("test_eval_fp8_accuracy.py", 303),
|
TestFile("test_eval_fp8_accuracy.py", 303),
|
||||||
TestFile("test_function_call_parser.py", 10),
|
TestFile("test_function_call_parser.py", 10),
|
||||||
|
TestFile("test_fused_moe.py", 30),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
@@ -142,6 +143,7 @@ suites = {
|
|||||||
TestFile("test_vertex_endpoint.py", 31),
|
TestFile("test_vertex_endpoint.py", 31),
|
||||||
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
|
# TestFile("test_vision_chunked_prefill.py", 175), # Disabled temporarily and track in #7701
|
||||||
TestFile("test_reasoning_parser.py", 5),
|
TestFile("test_reasoning_parser.py", 5),
|
||||||
|
TestFile("test_rope_rocm.py", 3),
|
||||||
],
|
],
|
||||||
"per-commit-npu": [
|
"per-commit-npu": [
|
||||||
TestFile("test_ascend_attention_backend.py", 400),
|
TestFile("test_ascend_attention_backend.py", 400),
|
||||||
|
|||||||
@@ -6,8 +6,14 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
|
||||||
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
|
from sglang.srt.utils import is_hip
|
||||||
from sglang.test.test_utils import CustomTestCase
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
|
|
||||||
class TestFusedMOE(CustomTestCase):
|
class TestFusedMOE(CustomTestCase):
|
||||||
NUM_EXPERTS = [8, 64]
|
NUM_EXPERTS = [8, 64]
|
||||||
@@ -64,7 +70,7 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
topk_weight = topk_weight.view(-1)
|
topk_weight = topk_weight.view(-1)
|
||||||
topk_ids = topk_ids.view(-1)
|
topk_ids = topk_ids.view(-1)
|
||||||
|
|
||||||
if w1.dtype == torch.float8_e4m3fn:
|
if w1.dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]:
|
||||||
w1_compute = w1.to(a.dtype)
|
w1_compute = w1.to(a.dtype)
|
||||||
w2_compute = w2.to(a.dtype)
|
w2_compute = w2.to(a.dtype)
|
||||||
|
|
||||||
@@ -97,7 +103,7 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
# AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||||
capability = torch.cuda.get_device_capability()
|
capability = torch.cuda.get_device_capability()
|
||||||
if not (capability[0] >= 9 or capability == (8, 9)):
|
if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)):
|
||||||
return
|
return
|
||||||
|
|
||||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||||
@@ -106,12 +112,26 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
w1 = w1.to(torch.float8_e4m3fn)
|
w1 = w1.to(torch.float8_e4m3fn)
|
||||||
w2 = w2.to(torch.float8_e4m3fn)
|
w2 = w2.to(torch.float8_e4m3fn)
|
||||||
score = self.create_random_cuda_tensor((m, e), dtype)
|
score = self.create_random_cuda_tensor((m, e), dtype)
|
||||||
|
|
||||||
w1_scale = self.create_random_cuda_tensor(e, torch.float32)
|
w1_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||||
w2_scale = self.create_random_cuda_tensor(e, torch.float32)
|
w2_scale = self.create_random_cuda_tensor(e, torch.float32)
|
||||||
a1_scale = self.create_random_cuda_tensor(1, torch.float32)
|
a1_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||||
a2_scale = self.create_random_cuda_tensor(1, torch.float32)
|
a2_scale = self.create_random_cuda_tensor(1, torch.float32)
|
||||||
|
|
||||||
|
# Handle HIP case: normalize float8 weights so fused kernel doesn't break
|
||||||
|
# on ROCm.
|
||||||
|
if _is_fp8_fnuz:
|
||||||
|
# Normalize to e4m3fnuz on HIP
|
||||||
|
w1, w1_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=w1,
|
||||||
|
weight_scale=w1_scale,
|
||||||
|
input_scale=a1_scale,
|
||||||
|
)
|
||||||
|
w2, w2_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=w2,
|
||||||
|
weight_scale=w2_scale,
|
||||||
|
input_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
sglang_output = fused_moe(
|
sglang_output = fused_moe(
|
||||||
a,
|
a,
|
||||||
w1,
|
w1,
|
||||||
@@ -127,12 +147,19 @@ class TestFusedMOE(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
torch_output = self.torch_naive_moe(
|
torch_output = self.torch_naive_moe(
|
||||||
a, w1, w2, score, topk, w1_scale, w2_scale, a1_scale, a2_scale
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
score,
|
||||||
|
topk,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
a1_scale,
|
||||||
|
a2_scale,
|
||||||
)
|
)
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
sglang_output, torch_output, rtol=rtol, atol=atol
|
sglang_output, torch_output, rtol=rtol, atol=atol
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
a = self.create_random_cuda_tensor((m, k), dtype)
|
a = self.create_random_cuda_tensor((m, k), dtype)
|
||||||
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
|
||||||
|
|||||||
116
test/srt/test_rope_rocm.py
Normal file
116
test/srt/test_rope_rocm.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.rotary_embedding import RotaryEmbedding
|
||||||
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
# Seed the global torch RNG at import time so the random query/key tensors
# drawn by the test cases below are reproducible from run to run.
torch.manual_seed(0)

# Platform switches evaluated once at import time; they drive the skipIf
# decorators on the two test classes.
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

# Each entry is unpacked positionally into the per-case runner, in this order:
#   (head_size, rotary_dim, max_pos, base, is_neox, dtype, device,
#    batch_size, seq_len, num_q, num_kv)
_CASES = [
    # is_neox=True
    (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
    (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
    (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
    # is_neox=False
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
    (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
    (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
]
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(_use_aiter, reason="SGLANG_USE_AITER=1 will not use vllm path.")
class TestRotaryEmbeddingNative(CustomTestCase):
    """Check that RotaryEmbedding.forward_hip() agrees with forward_native().

    The whole class is skipped when ``_use_aiter`` is true.
    """

    def _run_case(
        self,
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        """Run one configuration through both forward paths and compare them.

        Two identically configured ``RotaryEmbedding`` modules are built; one
        is driven through ``forward_native`` and the other through
        ``forward_hip``. Both receive clones of the same random query/key
        tensors, and the resulting query and key outputs must match within
        ``atol=1e-2`` / ``rtol=1e-2``.
        """
        ctor_args = (head_size, rotary_dim, max_pos, base, is_neox, dtype)
        reference_rope = RotaryEmbedding(*ctor_args).to(device)
        hip_rope = RotaryEmbedding(*ctor_args).to(device)

        # Positions 0..seq_len-1, tiled once per batch entry.
        num_tokens = batch_size * seq_len
        positions = torch.arange(seq_len, device=device).repeat(batch_size)
        query = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        key = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

        # Each path gets its own clones so neither call can observe changes
        # the other may have made to its inputs.
        q_ref, k_ref = reference_rope.forward_native(
            positions, query.clone(), key.clone()
        )
        q_hip, k_hip = hip_rope.forward_hip(positions, query.clone(), key.clone())

        for want, got in ((q_ref, q_hip), (k_ref, k_hip)):
            torch.testing.assert_close(want, got, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        """Exercise every entry of ``_CASES``, each inside its own subTest."""
        for params in _CASES:
            with self.subTest(case=params):
                self._run_case(*params)
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not _use_aiter, reason="Requires AMD GPU plus SGLANG_USE_AITER=1")
class TestRotaryEmbeddingAITer(CustomTestCase):
    """Check that aiter's RotaryEmbedding.forward_hip() matches forward_native().

    The whole class is skipped unless ``_use_aiter`` is true.
    """

    @staticmethod
    def _run_case_aiter(
        head_size: int,
        rotary_dim: int,
        max_pos: int,
        base: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: str,
        batch_size: int,
        seq_len: int,
        num_q: int,
        num_kv: int,
    ) -> None:
        """Run one configuration through both aiter forward paths and compare.

        Two identically configured aiter ``RotaryEmbedding`` modules are
        built; one is driven through ``forward_native`` and the other through
        ``forward_hip``. Both receive clones of the same random query/key
        tensors, and the outputs must match within ``atol=1e-2`` /
        ``rtol=1e-2``.
        """
        # Imported here rather than at module level, so loading this file
        # does not itself import aiter.
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        ctor_args = (head_size, rotary_dim, max_pos, base, is_neox, dtype)
        reference_rope = AiterRotaryEmbedding(*ctor_args).to(device)
        hip_rope = AiterRotaryEmbedding(*ctor_args).to(device)

        # Positions 0..seq_len-1, tiled once per batch entry.
        num_tokens = batch_size * seq_len
        positions = torch.arange(seq_len, device=device).repeat(batch_size)
        query = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        key = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

        # Each path gets its own clones so neither call can observe changes
        # the other may have made to its inputs.
        q_ref, k_ref = reference_rope.forward_native(
            positions, query.clone(), key.clone()
        )
        q_hip, k_hip = hip_rope.forward_hip(positions, query.clone(), key.clone())

        for want, got in ((q_ref, q_hip), (k_ref, k_hip)):
            torch.testing.assert_close(want, got, atol=1e-2, rtol=1e-2)

    def test_all_cases(self) -> None:
        # Exercise every entry of _CASES, each inside its own subTest.
        for params in _CASES:
            with self.subTest(case=params):
                self._run_case_aiter(*params)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user