[Misc] clean up vllm in sgl-kernel test (#5189)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
|
||||
bias = None
|
||||
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
torch.testing.assert_close(o, o1)
|
||||
torch.testing.assert_close(o, o2)
|
||||
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user