Update fa3 interface and add unit test (#9150)

This commit is contained in:
Ke Bao
2025-08-13 20:05:02 +08:00
committed by GitHub
parent 3b3b3baf9f
commit 94f44b88d1
4 changed files with 54 additions and 12 deletions

View File

@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor? scheduler_metadata,"
" int num_splits,"
" bool? pack_gqa,"
" int sm_margin) -> Tensor[]");
" int sm_margin,"
" Tensor? sinks) -> Tensor[]");
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
}