diff --git a/sgl-kernel/tests/test_awq_dequant.py b/sgl-kernel/tests/test_awq_dequant.py
index 33380180b..60fc8a148 100644
--- a/sgl-kernel/tests/test_awq_dequant.py
+++ b/sgl-kernel/tests/test_awq_dequant.py
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import awq_dequantize
-from vllm import _custom_ops as ops
 
 
 def reverse_awq_order(t: torch.Tensor):
@@ -58,12 +57,6 @@ def awq_dequantize_torch(
     return (iweights - zeros) * scales
 
 
-def vllm_awq_dequantize(
-    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
-) -> torch.Tensor:
-    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
-
-
 def sglang_awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.Tensor:
@@ -110,7 +103,6 @@ def test_awq_dequant_compare_implementations(
     )
 
     # Run both implementations
-    vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
     torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
     sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
 
@@ -118,13 +110,6 @@ def test_awq_dequant_compare_implementations(
     torch.testing.assert_close(
         torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     )
-    if not is_bf16_act:
-        torch.testing.assert_close(
-            vllm_out.to(torch.float32),
-            sglang_out.to(torch.float32),
-            rtol=1e-3,
-            atol=1e-5,
-        )
 
 
 if __name__ == "__main__":
diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py
index d87a9a5aa..9f4103e1d 100644
--- a/sgl-kernel/tests/test_int8_gemm.py
+++ b/sgl-kernel/tests/test_int8_gemm.py
@@ -1,7 +1,6 @@
 import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
-from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
 
 def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
         bias = None
     o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
-    o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     torch.testing.assert_close(o, o1)
-    torch.testing.assert_close(o, o2)
     print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
 
 
diff --git a/sgl-kernel/tests/test_per_tensor_quant_fp8.py b/sgl-kernel/tests/test_per_tensor_quant_fp8.py
index 70b05af5d..620fa2dba 100644
--- a/sgl-kernel/tests/test_per_tensor_quant_fp8.py
+++ b/sgl-kernel/tests/test_per_tensor_quant_fp8.py
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import sgl_per_tensor_quant_fp8
-from vllm import _custom_ops as ops
 
 from sglang.srt.utils import is_hip
 
@@ -12,13 +11,6 @@ is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 
-def vllm_scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input, scale)
-
-
 def sglang_scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
     return output, scale
 
+def torch_scaled_fp8_quant(tensor, inv_scale):
+    # The reference implementation that fully aligns with
+    # the kernel being tested.
+    finfo = torch.finfo(fp8_type_)
+    scale = inv_scale.reciprocal()
+    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+    qweight = qweight.to(fp8_type_)
+    return qweight
+
+
 @pytest.mark.parametrize(
     "num_tokens,hidden_dim",
     list(itertools.product([128, 256, 512], [512, 2048, 4096])),
 )
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
     device = torch.device("cuda")
     x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
 
-    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
+    torch_out = torch_scaled_fp8_quant(x, sglang_scale)
 
-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )
 
     scale = torch.rand(1, dtype=torch.float32, device=device)
-    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
     sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
+    torch_out = torch_scaled_fp8_quant(x, scale)
 
-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )
 
 
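The new `torch_scaled_fp8_quant` reference takes the kernel's own scale as input, so it verifies the cast-and-clamp arithmetic rather than scale selection. For readers who want the scale computation spelled out as well, here is a self-contained sketch of the full dynamic per-tensor path; the function name `ref_dynamic_per_tensor_fp8` and the amax / finfo.max scale rule are assumptions about what `sgl_per_tensor_quant_fp8` plausibly does internally, not code from this diff:

```python
import torch

# Hypothetical sketch, not part of the diff: dynamic per-tensor FP8
# quantization with the scale derived from the global absolute maximum.
# Assumes the CUDA fp8 dtype (float8_e4m3fn) and a CUDA device.
def ref_dynamic_per_tensor_fp8(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One scale for the whole tensor: map the global amax onto the fp8 range.
    scale = x.abs().amax().float() / finfo.max
    q = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn), scale


x = torch.rand((128, 512), dtype=torch.float16, device="cuda")
q, s = ref_dynamic_per_tensor_fp8(x)
# Dequantizing should recover x to within fp8 rounding error.
torch.testing.assert_close(q.float() * s, x.float(), rtol=1e-1, atol=1e-2)
```

Feeding `sglang_scale` back into `torch_scaled_fp8_quant`, as the updated test does, sidesteps any disagreement over this scale rule and keeps the comparison focused on the quantization step itself.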
diff --git a/sgl-kernel/tests/test_per_token_quant_fp8.py b/sgl-kernel/tests/test_per_token_quant_fp8.py
index fe1e0afe3..00a80ca01 100644
--- a/sgl-kernel/tests/test_per_token_quant_fp8.py
+++ b/sgl-kernel/tests/test_per_token_quant_fp8.py
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from sgl_kernel import sgl_per_token_quant_fp8
-from vllm import _custom_ops as ops
 
 from sglang.srt.utils import is_hip
 
@@ -12,10 +11,15 @@ is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
 
 
-def vllm_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
+def torch_per_token_quant_fp8(tensor, inv_scale):
+    # The reference implementation that fully aligns with
+    # the kernel being tested.
+    finfo = torch.finfo(fp8_type_)
+    inv_scale = inv_scale.view(-1, 1)
+    scale = inv_scale.reciprocal()
+    qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+    qweight = qweight.to(fp8_type_)
+    return qweight
 
 
 def sglang_per_token_quant_fp8(
@@ -41,12 +45,11 @@ def test_per_token_quant_compare_implementations(
     device = torch.device("cuda")
     x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
 
-    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
     sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
+    torch_out = torch_per_token_quant_fp8(x, sglang_scale)
 
-    torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(
-        vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
+        sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
     )
 
 
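The per-token test follows the same pattern with one scale per row. As above, a hedged sketch of what the dynamic path plausibly looks like end to end; `ref_dynamic_per_token_fp8` and its row-wise amax scale rule are illustrative assumptions about `sgl_per_token_quant_fp8`, not part of the diff:

```python
import torch

# Hypothetical sketch, not part of the diff: dynamic per-token FP8
# quantization, one scale per row, mirroring torch_per_token_quant_fp8's
# inv_scale.view(-1, 1) broadcasting. Assumes float8_e4m3fn and CUDA.
def ref_dynamic_per_token_fp8(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Row-wise amax mapped onto the fp8 range, kept 2-D for broadcasting.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / finfo.max
    q = (x.float() / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn), scale.squeeze(-1)


x = torch.rand((128, 512), dtype=torch.float16, device="cuda")
q, s = ref_dynamic_per_token_fp8(x)
# Round trip: dequantized values should match x within fp8 rounding error.
torch.testing.assert_close(q.float() * s.view(-1, 1), x.float(), rtol=1e-1, atol=1e-2)
```

Note the same division of labor as in the per-tensor test: `sglang_scale` is passed into `torch_per_token_quant_fp8`, so the assertion checks the per-token quantization math while leaving scale computation to the kernel under test.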