Add graph runner support with torch compile on CPU (#7843)
This commit is contained in:
@@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
   m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
   m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
-  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
   m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

   // topk
@@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   // decode
   m.def(
-      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
       "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
       "float logit_cap) -> ()");
   m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

   // extend
   m.def(
-      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
       "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
       "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
   m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
@@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

   // bmm
-  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
   m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

   // moe
@@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

   // all reduce
   m.def("initialize(int size, int rank) -> ()");
-  m.def("shm_allreduce(Tensor data, int reduce_op) -> ()");
+  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
   m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
   m.def("shm_allgather(Tensor data, int dim) -> Tensor");
   m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||
|
||||
Reference in New Issue
Block a user