[Misc] clean up vllm in sgl-kernel test (#5189)
This commit is contained in:
@@ -4,7 +4,6 @@ from typing import Optional, Tuple
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import sgl_per_tensor_quant_fp8
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
@@ -12,13 +11,6 @@ is_hip_ = is_hip()
|
||||
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
|
||||
|
||||
def vllm_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input, scale)
|
||||
|
||||
|
||||
def sglang_scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
|
||||
return output, scale
|
||||
|
||||
|
||||
def torch_scaled_fp8_quant(tensor, inv_scale):
|
||||
# The reference implementation that fully aligns to
|
||||
# the kernel being tested.
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
scale = inv_scale.reciprocal()
|
||||
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
return qweight
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,hidden_dim",
|
||||
list(itertools.product([128, 256, 512], [512, 2048, 4096])),
|
||||
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
|
||||
|
||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
|
||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
|
||||
torch_out = torch_scaled_fp8_quant(x, sglang_scale)
|
||||
|
||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
||||
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
|
||||
scale = torch.rand(1, dtype=torch.float32, device=device)
|
||||
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
|
||||
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
|
||||
torch_out = torch_scaled_fp8_quant(x, scale)
|
||||
|
||||
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
|
||||
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user