fix sgl-kernel unit tests (#5666)
This commit is contained in:
6
sgl-kernel/csrc/common_extension.cc
Executable file → Normal file
6
sgl-kernel/csrc/common_extension.cc
Executable file → Normal file
@@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"bool is_causal, float softcap, bool return_softmax, "
|
||||
"Generator? gen) -> Tensor[]");
|
||||
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
|
||||
|
||||
/*
|
||||
* From XGrammar
|
||||
*/
|
||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user