[1/2] Speed up prefill mla attention (#10156)
This commit is contained in:
@@ -436,6 +436,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
|
||||
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
|
||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user