[Misc] clean up vllm in sgl-kernel test (#5189)

This commit is contained in:
yinfan98
2025-04-09 16:22:13 +08:00
committed by GitHub
parent 61970b08d8
commit d2e507df3c
4 changed files with 25 additions and 40 deletions

View File

@@ -4,7 +4,6 @@ from typing import Optional, Tuple
import pytest
import torch
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops
from sglang.srt.utils import is_hip
@@ -12,10 +11,15 @@ is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
def vllm_per_token_quant_fp8(
input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
def torch_per_token_quant_fp8(tensor, inv_scale):
# The reference implementation that fully aligns to
# the kernel being tested.
finfo = torch.finfo(torch.float8_e4m3fn)
inv_scale = inv_scale.view(-1, 1)
scale = inv_scale.reciprocal()
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
qweight = qweight.to(torch.float8_e4m3fn)
return qweight
def sglang_per_token_quant_fp8(
@@ -41,12 +45,11 @@ def test_per_token_quant_compare_implementations(
device = torch.device("cuda")
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
torch_out = torch_per_token_quant_fp8(x, sglang_scale)
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
)