Update fa3 interface and add unit test (#9150)
This commit is contained in:
@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
" Tensor? scheduler_metadata,"
|
||||
" int num_splits,"
|
||||
" bool? pack_gqa,"
|
||||
" int sm_margin) -> Tensor[]");
|
||||
" int sm_margin,"
|
||||
" Tensor? sinks) -> Tensor[]");
|
||||
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user