fix sgl-kernel unit tests (#5666)

This commit is contained in:
Yineng Zhang
2025-04-23 01:18:30 -07:00
committed by GitHub
parent e62c49557d
commit 15fabcc07f
9 changed files with 313 additions and 0 deletions

6
sgl-kernel/csrc/common_extension.cc Executable file → Normal file
View File

@@ -233,6 +233,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"bool is_causal, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
/*
* From XGrammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
}
REGISTER_EXTENSION(common_ops)