Add bf16 output option for dsv3_router_gemm kernel (#7999)

This commit is contained in:
Baizhou Zhang
2025-07-19 18:49:37 -07:00
committed by GitHub
parent 4540a4666a
commit 282eb59ff3
7 changed files with 465 additions and 104 deletions

View File

@@ -15,17 +15,20 @@ def test_dsv3_router_gemm(num_tokens):
mat_b = torch.randn(
(num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda"
).contiguous()
output = torch.empty(
(num_tokens, num_experts), dtype=torch.float32, device="cuda"
).contiguous()
ref = F.linear(mat_a, mat_b).to(torch.float32)
bf16_ref = F.linear(mat_a, mat_b)
float_ref = bf16_ref.to(torch.float32)
output = dsv3_router_gemm(mat_a, mat_b)
bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
assert torch.allclose(
output, ref, rtol=1e-2, atol=1e-3
), "Router GEMM output mismatch with torch.nn.functional.linear reference"
bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"
assert torch.allclose(
float_output, float_ref, rtol=1e-2, atol=1e-3
), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"
if __name__ == "__main__":