Add dsv3 router gemm kernel (#7627)

This commit is contained in:
Baizhou Zhang
2025-06-29 23:31:55 -07:00
committed by GitHub
parent 22352d47a9
commit 7248272ccc
8 changed files with 398 additions and 0 deletions

View File

@@ -158,6 +158,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
/*
* From csrc/moe
*/