sglang/sgl-kernel/tests/test_per_token_group_quant_8bit.py

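"""Compare triton_per_token_group_quant_8bit against sglang_per_token_group_quant_8bit
across the scale-layout flag combinations listed in the parametrization below."""
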
import itertools

import pytest
import torch

from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.layers.quantization.utils import assert_fp8_all_close
from sglang.srt.utils import is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
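

# The `flags` dicts below exercise the scale-layout variants exposed by the
# kernels. Rough intent, inferred from the names rather than from kernel docs:
# column_major_scales -> scales stored column-major, scale_tma_aligned ->
# scale strides padded for TMA, scale_ue8m0 -> scales encoded as UE8M0
# (power-of-two) values for DeepGEMM on Blackwell.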
@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, dst_dtype, flags",
    list(
        itertools.product(
            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
            [256, 512, 1024, 2048, 4096],  # hidden_dim
            [8, 16, 32, 64, 128],  # group_size
            # TODO test int8
            [fp8_type_],  # dst_dtype
            [
                dict(
                    column_major_scales=False,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=False,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=False,
                ),
                dict(
                    column_major_scales=True,
                    scale_tma_aligned=True,
                    scale_ue8m0=True,
                ),
            ],
        )
    ),
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    dst_dtype,
    flags,
):
    if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
        pytest.skip("scale_ue8m0 requires group_size == 128 and hidden_dim % 512 == 0")
    if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
        pytest.skip("scale_ue8m0 only supported on Blackwell")
    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

    execute_kwargs = dict(
        x=x,
        group_size=group_size,
        eps=1e-10,
        dst_dtype=dst_dtype,
        **flags,
    )

    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)

    # torch.set_printoptions(profile="full")
    # print(f"{x_q_triton=}")
    # print(f"{x_s_triton=}")
    # print(f"{x_q_sglang=}")
    # print(f"{x_s_sglang=}")
    # torch.set_printoptions(profile="default")

    assert_fp8_all_close(x_q_triton, x_q_sglang)
    torch.testing.assert_close(
        x_s_triton.contiguous(),
        x_s_sglang.contiguous(),
        rtol=1e-3,
        atol=1e-5,
        msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
    )


if __name__ == "__main__":
    pytest.main([__file__])
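

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original test: a plain-PyTorch
# reference for per-token-group FP8 quantization with row-major scales,
# assuming the usual "scale = group amax / fp8_max" scheme that both kernels
# above are expected to follow. Handy for inspecting a failing case by hand.
# The helper name is hypothetical.
# ---------------------------------------------------------------------------
def _reference_per_token_group_quant_fp8(x, group_size, eps=1e-10, dst_dtype=fp8_type_):
    finfo = torch.finfo(dst_dtype)
    num_tokens, hidden_dim = x.shape
    assert hidden_dim % group_size == 0
    # View each row as contiguous groups of `group_size` elements.
    x_grouped = x.view(num_tokens, hidden_dim // group_size, group_size).float()
    # One scale per group: amax / fp8_max, kept away from zero by eps.
    scales = (x_grouped.abs().amax(dim=-1, keepdim=True) / finfo.max).clamp_min(eps)
    x_q = (x_grouped / scales).clamp(finfo.min, finfo.max).to(dst_dtype)
    return x_q.view(num_tokens, hidden_dim), scales.squeeze(-1)


# Example usage when debugging a failing parametrization:
#   x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
#   x_q_ref, x_s_ref = _reference_per_token_group_quant_fp8(x, group_size=128)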