[feat] add fa3 in sgl-kernel (#4902)

Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
yinfan98
2025-03-31 03:57:10 +08:00
committed by GitHub
parent 9adf178cc2
commit 37c66ec856
7 changed files with 1300 additions and 0 deletions

View File

@@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("top_p_renorm_probs", top_p_renorm_probs);
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
/*
* From flash-attention
*/
m.def("fwd", make_pytorch_shim(mha_fwd));
}
REGISTER_EXTENSION(common_ops)