Add dsv3 fused a gemm to sgl-kernel (#7630)

This commit is contained in:
Ke Bao
2025-06-29 17:52:24 +08:00
committed by GitHub
parent 071a1f51ae
commit 04b35190e2
9 changed files with 800 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_fused_a_gemm
@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_fused_a_gemm(num_tokens):
kHdIn = 7168
kHdOut = 2112
mat_a = torch.randn(
(num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda"
).contiguous()
mat_b = torch.randn((kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda").transpose(
0, 1
)
output = torch.empty(
(num_tokens, kHdOut), dtype=torch.bfloat16, device="cuda"
).contiguous()
ref = F.linear(mat_a, mat_b.T)
output = dsv3_fused_a_gemm(mat_a, mat_b)
assert torch.allclose(
output, ref, rtol=1e-2, atol=1e-3
), "Fused GEMM output mismatch with torch.nn.functional.linear reference"
if __name__ == "__main__":
pytest.main([__file__])