[Misc] clean up vllm in sgl-kernel test (#5189)
This commit is contained in:
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import awq_dequantize
|
from sgl_kernel import awq_dequantize
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
|
|
||||||
def reverse_awq_order(t: torch.Tensor):
|
def reverse_awq_order(t: torch.Tensor):
|
||||||
@@ -58,12 +57,6 @@ def awq_dequantize_torch(
|
|||||||
return (iweights - zeros) * scales
|
return (iweights - zeros) * scales
|
||||||
|
|
||||||
|
|
||||||
def vllm_awq_dequantize(
|
|
||||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
|
|
||||||
|
|
||||||
|
|
||||||
def sglang_awq_dequantize(
|
def sglang_awq_dequantize(
|
||||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@@ -110,7 +103,6 @@ def test_awq_dequant_compare_implementations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Run both implementations
|
# Run both implementations
|
||||||
vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
|
|
||||||
torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
|
torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
|
||||||
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
|
||||||
|
|
||||||
@@ -118,13 +110,6 @@ def test_awq_dequant_compare_implementations(
|
|||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||||
)
|
)
|
||||||
if not is_bf16_act:
|
|
||||||
torch.testing.assert_close(
|
|
||||||
vllm_out.to(torch.float32),
|
|
||||||
sglang_out.to(torch.float32),
|
|
||||||
rtol=1e-3,
|
|
||||||
atol=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import int8_scaled_mm
|
from sgl_kernel import int8_scaled_mm
|
||||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
|
||||||
|
|
||||||
|
|
||||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
|
|||||||
bias = None
|
bias = None
|
||||||
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||||
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||||
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
|
||||||
torch.testing.assert_close(o, o1)
|
torch.testing.assert_close(o, o1)
|
||||||
torch.testing.assert_close(o, o2)
|
|
||||||
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
@@ -12,13 +11,6 @@ is_hip_ = is_hip()
|
|||||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
def vllm_scaled_fp8_quant(
|
|
||||||
input: torch.Tensor,
|
|
||||||
scale: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
return ops.scaled_fp8_quant(input, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def sglang_scaled_fp8_quant(
|
def sglang_scaled_fp8_quant(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
scale: Optional[torch.Tensor] = None,
|
scale: Optional[torch.Tensor] = None,
|
||||||
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
|
|||||||
return output, scale
|
return output, scale
|
||||||
|
|
||||||
|
|
||||||
|
def torch_scaled_fp8_quant(tensor, inv_scale):
|
||||||
|
# The reference implementation that fully aligns to
|
||||||
|
# the kernel being tested.
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
scale = inv_scale.reciprocal()
|
||||||
|
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
qweight = qweight.to(torch.float8_e4m3fn)
|
||||||
|
return qweight
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_tokens,hidden_dim",
|
"num_tokens,hidden_dim",
|
||||||
list(itertools.product([128, 256, 512], [512, 2048, 4096])),
|
list(itertools.product([128, 256, 512], [512, 2048, 4096])),
|
||||||
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
|
|||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
|
|
||||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
|
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
|
||||||
|
torch_out = torch_scaled_fp8_quant(x, sglang_scale)
|
||||||
|
|
||||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
|
||||||
)
|
)
|
||||||
|
|
||||||
scale = torch.rand(1, dtype=torch.float32, device=device)
|
scale = torch.rand(1, dtype=torch.float32, device=device)
|
||||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
|
|
||||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
|
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
|
||||||
|
torch_out = torch_scaled_fp8_quant(x, scale)
|
||||||
|
|
||||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import sgl_per_token_quant_fp8
|
from sgl_kernel import sgl_per_token_quant_fp8
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
@@ -12,10 +11,15 @@ is_hip_ = is_hip()
|
|||||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
def vllm_per_token_quant_fp8(
|
def torch_per_token_quant_fp8(tensor, inv_scale):
|
||||||
input: torch.Tensor,
|
# The reference implementation that fully aligns to
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
# the kernel being tested.
|
||||||
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
inv_scale = inv_scale.view(-1, 1)
|
||||||
|
scale = inv_scale.reciprocal()
|
||||||
|
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
qweight = qweight.to(torch.float8_e4m3fn)
|
||||||
|
return qweight
|
||||||
|
|
||||||
|
|
||||||
def sglang_per_token_quant_fp8(
|
def sglang_per_token_quant_fp8(
|
||||||
@@ -41,12 +45,11 @@ def test_per_token_quant_compare_implementations(
|
|||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
||||||
|
|
||||||
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
|
|
||||||
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
|
||||||
|
torch_out = torch_per_token_quant_fp8(x, sglang_scale)
|
||||||
|
|
||||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user