sgl-kernel use cutlass latest version for fp8 blockwise gemm (#5207)

This commit is contained in:
Yi Zhang
2025-04-10 02:47:04 +08:00
committed by GitHub
parent 7f875f1293
commit ebf495f013
6 changed files with 86 additions and 923 deletions

View File

@@ -82,9 +82,9 @@ def _test_accuracy_once(M, N, K, out_dtype, device):
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("M", [1, 3, 5, 127, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096, 8192, 14080])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 14080, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
_test_accuracy_once(M, N, K, out_dtype, "cuda")