[feat] add fa3 in sgl-kernel (#4902)
Co-authored-by: Sleepcoo <Sleepcoo@gmail.com>
This commit is contained in:
@@ -91,6 +91,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.def("top_p_renorm_probs", top_p_renorm_probs);
|
||||
m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
|
||||
m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
|
||||
|
||||
/*
|
||||
* From flash-attention
|
||||
*/
|
||||
m.def("fwd", make_pytorch_shim(mha_fwd));
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user