sgl-kernel: use the latest CUTLASS version for the FP8 blockwise GEMM (#5207)
This commit is contained in:
@@ -82,9 +82,9 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
|
||||
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
|
||||
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
|
||||
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
|
||||
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
|
||||
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
def test_accuracy(M, N, K, out_dtype):
|
||||
_test_accuracy_once(M, N, K, out_dtype, "cuda")
|
||||
|
||||
Reference in New Issue
Block a user