Update fa3 interface and add unit test (#9150)
This commit is contained in:
@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
|
||||
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
|
||||
int num_splits,
|
||||
std::optional<bool> pack_gqa_,
|
||||
int const sm_margin);
|
||||
int const sm_margin,
|
||||
std::optional<const at::Tensor>& sinks_);
|
||||
|
||||
Reference in New Issue
Block a user