[Misc] clean up vllm in sgl-kernel test (#5189)

This commit is contained in:
yinfan98
2025-04-09 16:22:13 +08:00
committed by GitHub
parent 61970b08d8
commit d2e507df3c
4 changed files with 25 additions and 40 deletions

View File

@@ -1,7 +1,6 @@
import pytest
import torch
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
bias = None
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
torch.testing.assert_close(o, o2)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")