[1/2] Support QServe (#6457)
Co-authored-by: yych0745 <1398089567@qq.com> Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -265,6 +265,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
*/
|
||||
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
|
||||
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
|
||||
|
||||
/*
|
||||
* From QServe
|
||||
*/
|
||||
m.def(
|
||||
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
|
||||
"Tensor _a_ssums, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_chn_gemm", torch::kCUDA, &qserve_w4a8_per_chn_gemm);
|
||||
|
||||
m.def(
|
||||
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||
}
|
||||
|
||||
// NOTE(review): REGISTER_EXTENSION is a macro defined outside this view; presumably it
// emits the Python module entry point for the extension named `common_ops` so the
// operators registered above become loadable — confirm against the macro definition.
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user