Support sm90 Int8 gemm (#3035)

This commit is contained in:
Ke Bao
2025-01-21 22:21:54 +08:00
committed by GitHub
parent 5a0d680a14
commit 0ac019f171
2 changed files with 210 additions and 2 deletions

View File

@@ -25,7 +25,7 @@ class TestInt8Gemm(unittest.TestCase):
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
-bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10
+bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None