diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt
index fd3c8aae3..355a6ab47 100755
--- a/sgl-kernel/csrc/cpu/CMakeLists.txt
+++ b/sgl-kernel/csrc/cpu/CMakeLists.txt
@@ -5,9 +5,7 @@ set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 
-# Torch
-find_package(Torch REQUIRED)
-find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
+find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
 
 execute_process(
   COMMAND ${Python_EXECUTABLE}
@@ -23,8 +21,9 @@ find_package(Torch REQUIRED)
 include_directories(
   ${TORCH_INCLUDE_DIRS}
   ${TORCH_INSTALL_PREFIX}/include
-  ${Python3_INCLUDE_DIRS}
-  ${CMAKE_SOURCE_DIR}/csrc
+  ${Python_INCLUDE_DIRS}
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../include
 )
 
 # Platform-specific library directory
@@ -39,23 +38,7 @@ else()
 endif()
 link_directories(${PLAT_LIB_DIR})
 
-set(SOURCES
-  ${CMAKE_CURRENT_SOURCE_DIR}/activation.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/bmm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/decode.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/extend.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/gemm_int8.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/moe.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/moe_int8.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/norm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/qkv_proj.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/topk.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/interface.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/shm.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/torch_extension_cpu.cpp
-)
+file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
 
 add_compile_options(
   -O3
@@ -64,24 +47,10 @@ add_compile_options(
   -fopenmp
 )
 
-add_library(sgl_kernel_common_ops SHARED ${SOURCES})
+Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES})
+target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
 
-target_link_libraries(sgl_kernel_common_ops
-  PRIVATE
-  ${TORCH_LIBRARIES}
-  ${Python3_LIBRARIES}
-  c10
-)
-
-set_target_properties(sgl_kernel_common_ops PROPERTIES
-  INSTALL_RPATH "$ORIGIN/../../torch/lib"
-  PREFIX ""
-  OUTPUT_NAME "sgl_kernel.common_ops"
-)
-
-target_compile_definitions(sgl_kernel_common_ops PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
-
-# Install
-install(TARGETS sgl_kernel_common_ops
-  LIBRARY DESTINATION ${Python3_SITEARCH}
+install(TARGETS common_ops
+  LIBRARY DESTINATION sgl_kernel
 )
diff --git a/sgl-kernel/csrc/cpu/bmm.cpp b/sgl-kernel/csrc/cpu/bmm.cpp
index f7377a09c..9e809a464 100644
--- a/sgl-kernel/csrc/cpu/bmm.cpp
+++ b/sgl-kernel/csrc/cpu/bmm.cpp
@@ -74,7 +74,8 @@ void bmm_kernel_impl(
 // out  : [B, M, N]
 // scale: [] 0-dim tensor for per tensor quant
 //
-void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
+void bmm_cpu(
+    at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
   RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp
index 68dbd4896..8cdebb9a2 100644
--- a/sgl-kernel/csrc/cpu/gemm.cpp
+++ b/sgl-kernel/csrc/cpu/gemm.cpp
@@ -463,7 +463,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
 // bias : [N]
 // out  : [M, N]
 //
-at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
+at::Tensor
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
index 20dfed2da..1feded107 100644
--- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp
+++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
@@ -482,7 +482,7 @@ at::Tensor fp8_scaled_mm_cpu(
     at::Tensor& mat2,
     at::Tensor& scales2,
     std::vector<int64_t> block_size,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
diff --git a/sgl-kernel/csrc/cpu/gemm_int8.cpp b/sgl-kernel/csrc/cpu/gemm_int8.cpp
index a7a87ce74..f0f013cd1 100644
--- a/sgl-kernel/csrc/cpu/gemm_int8.cpp
+++ b/sgl-kernel/csrc/cpu/gemm_int8.cpp
@@ -366,7 +366,7 @@ at::Tensor int8_scaled_mm_cpu(
     at::Tensor& mat2,
     at::Tensor& scales1,
     at::Tensor& scales2,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
@@ -424,7 +424,7 @@ at::Tensor int8_scaled_mm_with_quant(
     at::Tensor& mat1,
     at::Tensor& mat2,
     at::Tensor& scales2,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp
index 633a3ada7..61d9686d6 100644
--- a/sgl-kernel/csrc/cpu/interface.cpp
+++ b/sgl-kernel/csrc/cpu/interface.cpp
@@ -11,7 +11,7 @@ static bool is_initialized = false;
 
 static bool all_ranks_local_p = false;
 
-void initialize(int size, int rank) {
+void initialize(int64_t size, int64_t rank) {
   if (is_initialized) {
     return;
   }
@@ -47,12 +47,11 @@ void initialize(int size, int rank) {
   }
 }
 
-void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
+void shm_allreduce(
+    torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
   RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
 
-  static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
-  static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
-  TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
+  TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
 
   auto numel = data.numel();
 
@@ -81,7 +80,7 @@ void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> p
   return;
 }
 
-torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
+torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
   RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
 
   auto numel = data.numel();
diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index 6e12f1e38..e1a9a9f85 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -946,10 +946,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
   RECORD_FUNCTION(
       "sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
@@ -1138,11 +1138,11 @@ at::Tensor shared_expert_cpu(
     bool inplace,
     bool use_int8_w8a8,
     bool use_fp8_w8a16,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<std::vector<int64_t>> block_size,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp
index 959072878..1a5361941 100644
--- a/sgl-kernel/csrc/cpu/qkv_proj.cpp
+++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp
@@ -308,18 +308,18 @@ void rotary_emb_kernel_impl(
 }  // anonymous namespace
 
 extern at::Tensor
-weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
 
 extern at::Tensor int8_scaled_mm_with_quant(
     at::Tensor& mat1,
     at::Tensor& mat2,
     at::Tensor& scales2,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni);
 
 extern void
-bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
+bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
 // NB: shapes in DeepDeek R1
 //
@@ -343,9 +343,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& q_a_proj_scale,
-    std::optional<at::Tensor>& q_b_proj_scale,
-    std::optional<at::Tensor>& kv_a_proj_scale,
+    std::optional<at::Tensor> q_a_proj_scale,
+    std::optional<at::Tensor> q_b_proj_scale,
+    std::optional<at::Tensor> kv_a_proj_scale,
     bool is_vnni) {
   RECORD_FUNCTION(
       "sgl-kernel::qkv_proj_with_rope",
diff --git a/sgl-kernel/csrc/cpu/shm.h b/sgl-kernel/csrc/cpu/shm.h
index d21fe3d36..4419222a1 100644
--- a/sgl-kernel/csrc/cpu/shm.h
+++ b/sgl-kernel/csrc/cpu/shm.h
@@ -1,4 +1,4 @@
-#include
+#include
 
 #include
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index aa28c7ed8..bfc367606 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -17,6 +17,7 @@ limitations under the License.
 #include
 #include
 
+#include "sgl_kernel_ops.h"
 #include "shm.h"
 
 // silu_and_mul
@@ -85,7 +86,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight);
 std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
 
 // gemm
-at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
+at::Tensor
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
 
 // igemm
 at::Tensor int8_scaled_mm_cpu(
@@ -93,7 +95,7 @@ at::Tensor int8_scaled_mm_cpu(
     at::Tensor& mat2,
     at::Tensor& scales1,
     at::Tensor& scales2,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni);
 
@@ -103,7 +105,7 @@ at::Tensor fp8_scaled_mm_cpu(
     at::Tensor& mat2,
     at::Tensor& scales2,
     std::vector<int64_t> block_size,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni);
 
@@ -112,12 +114,12 @@ at::Tensor int8_scaled_mm_with_quant(
     at::Tensor& mat1,
     at::Tensor& mat2,
     at::Tensor& scales2,
-    std::optional<at::Tensor>& bias,
+    const std::optional<at::Tensor>& bias,
     at::ScalarType out_dtype,
     bool is_vnni);
 
 // bmm
-void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
+void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
 // fused moe
 at::Tensor fused_experts_cpu(
@@ -128,10 +130,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni);
 
 at::Tensor shared_expert_cpu(
@@ -143,11 +145,11 @@ at::Tensor shared_expert_cpu(
     bool inplace,
     bool use_int8_w8a8,
     bool use_fp8_w8a16,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<std::vector<int64_t>> block_size,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni);
 
 // weight absorption
@@ -163,80 +165,130 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& q_a_proj_scale,
-    std::optional<at::Tensor>& q_b_proj_scale,
-    std::optional<at::Tensor>& kv_a_proj_scale,
+    std::optional<at::Tensor> q_a_proj_scale,
+    std::optional<at::Tensor> q_b_proj_scale,
+    std::optional<at::Tensor> kv_a_proj_scale,
     bool is_vnni);
 
 // shared memory init
-void initialize(int size, int rank);
+void initialize(int64_t size, int64_t rank);
 
 // shared mmeory all_reduce
-void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);
+void shm_allreduce(
+    at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
 
 // shared memory all_gather
-at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
+at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
 
 // rope
 std::tuple<at::Tensor, at::Tensor>
 rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // activation
-  m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
+  m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
 
   // norm
-  m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
-  m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
+  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
+  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
+  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
 
   // topk
-  m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
+  m.def(
+      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
+      "int topk_group) -> (Tensor, Tensor)");
+  m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
 
   // biased group topk
-  m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
+  m.def(
+      "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
+      "renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
+  m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
 
   // decode
-  m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
+  m.def(
+      "decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
+      "req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
+  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
 
   // extend
-  m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
+  m.def(
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
+      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
+  m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
 
   // weight prepack
-  m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
+  m.def("convert_weight_packed(Tensor weight) -> Tensor");
+  m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
 
   // quant
-  m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
+  m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
+  m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
 
   // gemm
-  m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
+  m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
+  m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
 
   // igemm
-  m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
+  m.def(
+      "int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
+      "out_dtype, bool is_vnni) -> Tensor");
+  m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
 
   // fp8 gemm
-  m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
+  m.def(
+      "fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
+      "out_dtype, bool is_vnni) -> Tensor");
+  m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
 
   // quant + igemm
   m.def(
-      "int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
+      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
+      "is_vnni) -> Tensor");
+  m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
 
   // bmm
-  m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
+  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
 
   // moe
-  m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
+  m.def(
+      "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
+      "inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
+      "is_vnni) -> Tensor");
+  m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
 
   // weight absorption
-  m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
+  m.def(
+      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
+      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
+      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
+      "kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
+  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
 
   // shared expert
-  m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
+  m.def(
+      "shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
+      "routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
+      "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
+  m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
 
   // all reduce
-  m.def("initialize", &initialize, "shared memory initialization for CPU");
-  m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
-  m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
+  m.def("initialize(int size, int rank) -> ()");
+  m.impl("initialize", torch::kCPU, &initialize);
+  m.def(
+      "shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
+      "__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
+  m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
+  m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
+  m.impl("shm_allgather", torch::kCPU, &shm_allgather);
 
   // rope
-  m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
+  m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
+  m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
 }
+
+REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml
index 0c42c4032..a45054e72 100644
--- a/sgl-kernel/pyproject_cpu.toml
+++ b/sgl-kernel/pyproject_cpu.toml
@@ -34,7 +34,3 @@ exclude = [
 cmake.source-dir = "csrc/cpu"
 cmake.build-type = "Release"
 minimum-version = "build-system.requires"
-
-wheel.py-api = "cp39"
-wheel.license-files = []
-wheel.packages = ["python/sgl_kernel"]
diff --git a/sgl-kernel/setup_cpu.py b/sgl-kernel/setup_cpu.py
index b5f182dc2..9fc07700b 100644
--- a/sgl-kernel/setup_cpu.py
+++ b/sgl-kernel/setup_cpu.py
@@ -50,7 +50,9 @@ def _get_version():
 cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
 
 operator_namespace = "sgl_kernel"
-include_dirs = []
+include_dirs = [
+    "../../include",
+]
 
 sources = [
     "csrc/cpu/activation.cpp",
@@ -99,7 +101,7 @@ ext_modules = [
         extra_compile_args=extra_compile_args,
         libraries=libraries,
         extra_link_args=extra_link_args,
-        py_limited_api=True,
+        py_limited_api=False,
     ),
 ]
diff --git a/test/srt/cpu/test_gemm.py b/test/srt/cpu/test_gemm.py
index cc94bd3a0..bb4094f0d 100644
--- a/test/srt/cpu/test_gemm.py
+++ b/test/srt/cpu/test_gemm.py
@@ -1,18 +1,10 @@
 import itertools
 import unittest
 
+# TODO: use interface in cpu.py
+import sgl_kernel
 import torch
 import torch.nn as nn
-
-# TODO: use interface in cpu.py
-from sgl_kernel.common_ops import (
-    convert_weight_packed,
-    fp8_scaled_mm_cpu,
-    int8_scaled_mm_cpu,
-    int8_scaled_mm_with_quant,
-    per_token_quant_int8_cpu,
-    weight_packed_linear,
-)
 from utils import (
     convert_weight,
     native_w8a8_per_token_matmul,
@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
 
         ref = ref.bfloat16()
 
-        out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
+        out = torch.ops.sgl_kernel.weight_packed_linear(
+            mat1, mat2, bias if has_bias else None, False
+        )
 
-        packed_mat2 = convert_weight_packed(mat2)
-        out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
+        packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
+        out2 = torch.ops.sgl_kernel.weight_packed_linear(
+            mat1, packed_mat2, bias if has_bias else None, True
+        )
 
         atol = rtol = precision[ref.dtype]
         self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
 
         atol = rtol = precision[ref_out.dtype]
 
-        Aq2, As2 = per_token_quant_int8_cpu(A)
-        out = int8_scaled_mm_cpu(
+        Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
+        out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
             Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
         )
         self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
 
         # test the fused version
-        fused_out = int8_scaled_mm_with_quant(
+        fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
             A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
         )
         self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
         ref = torch.matmul(data.to(A_dtype), dq_weight.T)
 
         if prepack:
-            fp8_weight = convert_weight_packed(fp8_weight)
+            fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
 
-        opt = fp8_scaled_mm_cpu(
+        opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
             data,
             fp8_weight,
             scales,
diff --git a/test/srt/cpu/test_shared_expert.py b/test/srt/cpu/test_shared_expert.py
index 900d985f4..ea048495c 100644
--- a/test/srt/cpu/test_shared_expert.py
+++ b/test/srt/cpu/test_shared_expert.py
@@ -2,12 +2,10 @@ import itertools
 import math
 import unittest
 
+# TODO: use interface in cpu.py
+import sgl_kernel
 import torch
 import torch.nn as nn
-
-# TODO: use interface in cpu.py
-from sgl_kernel.common_ops import convert_weight_packed
-from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
 from utils import (
     BLOCK_K,
     BLOCK_N,
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
             fused_output.float(),
             routed_scaling_factor,
         ).to(dtype=dtype)
-        res = shared_expert(
+        res = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            w1,
            w2,
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
             fused_output.float(),
             routed_scaling_factor,
         ).to(dtype=dtype)
-        res2 = shared_expert(
+        res2 = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states2,
            w1_q,
            w2_q,
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
         ref_out = shared_out + fused_out.float() * routed_scaling_factor
         ref_out = ref_out.to(dtype=dtype)
 
-        w1 = convert_weight_packed(w1)  # [2N, K]
-        w2 = convert_weight_packed(w2)  # [K, N]
-        out = shared_expert(
+        w1 = torch.ops.sgl_kernel.convert_weight_packed(w1)  # [2N, K]
+        w2 = torch.ops.sgl_kernel.convert_weight_packed(w2)  # [K, N]
+        out = torch.ops.sgl_kernel.shared_expert_cpu(
            a2,
            w1,
            w2,