[1/2] Support Qserve (#6457)

Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
HandH1998
2025-05-22 10:48:59 +08:00
committed by GitHub
parent 6ce0ed073b
commit 4d643f6c7a
10 changed files with 2086 additions and 0 deletions

View File

@@ -265,6 +265,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From QServe
*
* Registers two GEMM operators in this library fragment. For each operator,
* m.def declares the schema and m.impl binds the C++ function of the same
* name to the CUDA dispatch key (torch::kCUDA), so no implementation is
* registered here for other backends.
*
* Both schemas return () and annotate the last argument, _out_feats, as
* `Tensor!`, which marks it as mutated by the op; the remaining arguments are
* plain (non-mutable, non-optional) Tensors.
*
* NOTE(review): "w4a8" presumably stands for 4-bit weights / 8-bit
* activations, and "per_chn" / "per_group" for the weight-quantization
* granularity -- confirm against the kernel sources, which are not visible
* here. Expected shapes and dtypes of the arguments are likewise not
* established by this file.
*/
// Variant taking _wscales, _ascales, _w_szs and _a_ssums alongside the inputs.
m.def(
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
// Variant taking _zeros, _scales_i8, _wscales and _ascales alongside the inputs.
m.def(
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
"Tensor _ascales, Tensor! _out_feats) -> ()");
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
}
REGISTER_EXTENSION(common_ops)