[CPU] Fix build issue (#6419)
This commit is contained in:
@@ -5,9 +5,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
# Torch
|
||||
find_package(Torch REQUIRED)
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE}
|
||||
@@ -23,8 +21,9 @@ find_package(Torch REQUIRED)
|
||||
include_directories(
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${TORCH_INSTALL_PREFIX}/include
|
||||
${Python3_INCLUDE_DIRS}
|
||||
${CMAKE_SOURCE_DIR}/csrc
|
||||
${Python_INCLUDE_DIRS}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||
)
|
||||
|
||||
# Platform-specific library directory
|
||||
@@ -39,23 +38,7 @@ else()
|
||||
endif()
|
||||
link_directories(${PLAT_LIB_DIR})
|
||||
|
||||
set(SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/activation.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/bmm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/decode.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extend.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/gemm_int8.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/moe.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/moe_int8.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/norm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/qkv_proj.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/topk.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/interface.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/shm.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/torch_extension_cpu.cpp
|
||||
)
|
||||
file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
|
||||
|
||||
add_compile_options(
|
||||
-O3
|
||||
@@ -64,24 +47,10 @@ add_compile_options(
|
||||
-fopenmp
|
||||
)
|
||||
|
||||
add_library(sgl_kernel_common_ops SHARED ${SOURCES})
|
||||
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES})
|
||||
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||
|
||||
target_link_libraries(sgl_kernel_common_ops
|
||||
PRIVATE
|
||||
${TORCH_LIBRARIES}
|
||||
${Python3_LIBRARIES}
|
||||
c10
|
||||
)
|
||||
|
||||
set_target_properties(sgl_kernel_common_ops PROPERTIES
|
||||
INSTALL_RPATH "$ORIGIN/../../torch/lib"
|
||||
PREFIX ""
|
||||
OUTPUT_NAME "sgl_kernel.common_ops"
|
||||
)
|
||||
|
||||
target_compile_definitions(sgl_kernel_common_ops PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
|
||||
|
||||
# Install
|
||||
install(TARGETS sgl_kernel_common_ops
|
||||
LIBRARY DESTINATION ${Python3_SITEARCH}
|
||||
install(TARGETS common_ops
|
||||
LIBRARY DESTINATION sgl_kernel
|
||||
)
|
||||
|
||||
@@ -74,7 +74,8 @@ void bmm_kernel_impl(
|
||||
// out : [B, M, N]
|
||||
// scale: [] 0-dim tensor for per tensor quant
|
||||
//
|
||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
|
||||
void bmm_cpu(
|
||||
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
|
||||
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
@@ -463,7 +463,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
at::Tensor
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
@@ -482,7 +482,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
@@ -366,7 +366,7 @@ at::Tensor int8_scaled_mm_cpu(
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales1,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
@@ -424,7 +424,7 @@ at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
@@ -11,7 +11,7 @@ static bool is_initialized = false;
|
||||
|
||||
static bool all_ranks_local_p = false;
|
||||
|
||||
void initialize(int size, int rank) {
|
||||
void initialize(int64_t size, int64_t rank) {
|
||||
if (is_initialized) {
|
||||
return;
|
||||
}
|
||||
@@ -47,12 +47,11 @@ void initialize(int size, int rank) {
|
||||
}
|
||||
}
|
||||
|
||||
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
|
||||
void shm_allreduce(
|
||||
torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
|
||||
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
|
||||
|
||||
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
|
||||
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
|
||||
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
|
||||
TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
|
||||
|
||||
auto numel = data.numel();
|
||||
|
||||
@@ -81,7 +80,7 @@ void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> p
|
||||
return;
|
||||
}
|
||||
|
||||
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
|
||||
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
|
||||
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
|
||||
|
||||
auto numel = data.numel();
|
||||
|
||||
@@ -946,10 +946,10 @@ at::Tensor fused_experts_cpu(
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
|
||||
@@ -1138,11 +1138,11 @@ at::Tensor shared_expert_cpu(
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<std::vector<int64_t>> block_size,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
|
||||
|
||||
|
||||
@@ -308,18 +308,18 @@ void rotary_emb_kernel_impl(
|
||||
} // anonymous namespace
|
||||
|
||||
extern at::Tensor
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
|
||||
extern at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
extern void
|
||||
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
||||
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
||||
|
||||
// NB: shapes in DeepDeek R1
|
||||
//
|
||||
@@ -343,9 +343,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& q_a_proj_scale,
|
||||
std::optional<at::Tensor>& q_b_proj_scale,
|
||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
||||
std::optional<at::Tensor> q_a_proj_scale,
|
||||
std::optional<at::Tensor> q_b_proj_scale,
|
||||
std::optional<at::Tensor> kv_a_proj_scale,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::qkv_proj_with_rope",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#include <torch/torch.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
#include "shm.h"
|
||||
|
||||
// silu_and_mul
|
||||
@@ -85,7 +86,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
|
||||
|
||||
// gemm
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
at::Tensor
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
|
||||
// igemm
|
||||
at::Tensor int8_scaled_mm_cpu(
|
||||
@@ -93,7 +95,7 @@ at::Tensor int8_scaled_mm_cpu(
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales1,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
@@ -103,7 +105,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
@@ -112,12 +114,12 @@ at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
// bmm
|
||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
||||
|
||||
// fused moe
|
||||
at::Tensor fused_experts_cpu(
|
||||
@@ -128,10 +130,10 @@ at::Tensor fused_experts_cpu(
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor shared_expert_cpu(
|
||||
@@ -143,11 +145,11 @@ at::Tensor shared_expert_cpu(
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<std::vector<int64_t>> block_size,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
|
||||
// weight absorption
|
||||
@@ -163,80 +165,130 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& q_a_proj_scale,
|
||||
std::optional<at::Tensor>& q_b_proj_scale,
|
||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
||||
std::optional<at::Tensor> q_a_proj_scale,
|
||||
std::optional<at::Tensor> q_b_proj_scale,
|
||||
std::optional<at::Tensor> kv_a_proj_scale,
|
||||
bool is_vnni);
|
||||
|
||||
// shared memory init
|
||||
void initialize(int size, int rank);
|
||||
void initialize(int64_t size, int64_t rank);
|
||||
|
||||
// shared mmeory all_reduce
|
||||
void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);
|
||||
void shm_allreduce(
|
||||
at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
|
||||
|
||||
// shared memory all_gather
|
||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
|
||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
||||
|
||||
// rope
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
// activation
|
||||
m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
|
||||
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
|
||||
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
|
||||
|
||||
// norm
|
||||
m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
|
||||
m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
|
||||
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
||||
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
|
||||
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
|
||||
|
||||
// topk
|
||||
m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
|
||||
m.def(
|
||||
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
|
||||
"int topk_group) -> (Tensor, Tensor)");
|
||||
m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
|
||||
|
||||
// biased group topk
|
||||
m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
|
||||
m.def(
|
||||
"biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
|
||||
"renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
|
||||
m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
|
||||
|
||||
// decode
|
||||
m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
|
||||
m.def(
|
||||
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
|
||||
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
|
||||
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
|
||||
|
||||
// extend
|
||||
m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
|
||||
m.def(
|
||||
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
|
||||
"Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
|
||||
"extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
|
||||
m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
|
||||
|
||||
// weight prepack
|
||||
m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
|
||||
m.def("convert_weight_packed(Tensor weight) -> Tensor");
|
||||
m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||
|
||||
// quant
|
||||
m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
|
||||
m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
|
||||
m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
|
||||
|
||||
// gemm
|
||||
m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
|
||||
m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
|
||||
m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||
|
||||
// igemm
|
||||
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
|
||||
m.def(
|
||||
"int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
|
||||
"out_dtype, bool is_vnni) -> Tensor");
|
||||
m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
|
||||
|
||||
// fp8 gemm
|
||||
m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
|
||||
m.def(
|
||||
"fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
|
||||
"out_dtype, bool is_vnni) -> Tensor");
|
||||
m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
|
||||
|
||||
// quant + igemm
|
||||
m.def(
|
||||
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
|
||||
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
|
||||
"is_vnni) -> Tensor");
|
||||
m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
|
||||
|
||||
// bmm
|
||||
m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
|
||||
m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
|
||||
m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
|
||||
|
||||
// moe
|
||||
m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
|
||||
m.def(
|
||||
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
|
||||
"inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
|
||||
"is_vnni) -> Tensor");
|
||||
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
|
||||
// weight absorption
|
||||
m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
|
||||
m.def(
|
||||
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
|
||||
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
|
||||
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
|
||||
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
|
||||
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
|
||||
|
||||
// shared expert
|
||||
m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
|
||||
m.def(
|
||||
"shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
|
||||
"routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
|
||||
"w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
|
||||
m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
|
||||
|
||||
// all reduce
|
||||
m.def("initialize", &initialize, "shared memory initialization for CPU");
|
||||
m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
|
||||
m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
|
||||
m.def("initialize(int size, int rank) -> ()");
|
||||
m.impl("initialize", torch::kCPU, &initialize);
|
||||
m.def(
|
||||
"shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||
"__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
|
||||
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||
m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
|
||||
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||
|
||||
// rope
|
||||
m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
|
||||
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
|
||||
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
@@ -34,7 +34,3 @@ exclude = [
|
||||
cmake.source-dir = "csrc/cpu"
|
||||
cmake.build-type = "Release"
|
||||
minimum-version = "build-system.requires"
|
||||
|
||||
wheel.py-api = "cp39"
|
||||
wheel.license-files = []
|
||||
wheel.packages = ["python/sgl_kernel"]
|
||||
|
||||
@@ -50,7 +50,9 @@ def _get_version():
|
||||
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
|
||||
|
||||
operator_namespace = "sgl_kernel"
|
||||
include_dirs = []
|
||||
include_dirs = [
|
||||
"../../include",
|
||||
]
|
||||
|
||||
sources = [
|
||||
"csrc/cpu/activation.cpp",
|
||||
@@ -99,7 +101,7 @@ ext_modules = [
|
||||
extra_compile_args=extra_compile_args,
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
py_limited_api=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import (
|
||||
convert_weight_packed,
|
||||
fp8_scaled_mm_cpu,
|
||||
int8_scaled_mm_cpu,
|
||||
int8_scaled_mm_with_quant,
|
||||
per_token_quant_int8_cpu,
|
||||
weight_packed_linear,
|
||||
)
|
||||
from utils import (
|
||||
convert_weight,
|
||||
native_w8a8_per_token_matmul,
|
||||
@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
|
||||
|
||||
ref = ref.bfloat16()
|
||||
|
||||
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
|
||||
out = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, mat2, bias if has_bias else None, False
|
||||
)
|
||||
|
||||
packed_mat2 = convert_weight_packed(mat2)
|
||||
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
|
||||
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
|
||||
out2 = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
mat1, packed_mat2, bias if has_bias else None, True
|
||||
)
|
||||
|
||||
atol = rtol = precision[ref.dtype]
|
||||
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
|
||||
@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
|
||||
|
||||
atol = rtol = precision[ref_out.dtype]
|
||||
|
||||
Aq2, As2 = per_token_quant_int8_cpu(A)
|
||||
out = int8_scaled_mm_cpu(
|
||||
Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
|
||||
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
|
||||
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||
|
||||
# test the fused version
|
||||
fused_out = int8_scaled_mm_with_quant(
|
||||
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
||||
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||
)
|
||||
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
|
||||
@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
|
||||
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
|
||||
|
||||
if prepack:
|
||||
fp8_weight = convert_weight_packed(fp8_weight)
|
||||
fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
|
||||
|
||||
opt = fp8_scaled_mm_cpu(
|
||||
opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
|
||||
data,
|
||||
fp8_weight,
|
||||
scales,
|
||||
|
||||
@@ -2,12 +2,10 @@ import itertools
|
||||
import math
|
||||
import unittest
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
import sgl_kernel
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# TODO: use interface in cpu.py
|
||||
from sgl_kernel.common_ops import convert_weight_packed
|
||||
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
|
||||
from utils import (
|
||||
BLOCK_K,
|
||||
BLOCK_N,
|
||||
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res = shared_expert(
|
||||
res = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
|
||||
fused_output.float(),
|
||||
routed_scaling_factor,
|
||||
).to(dtype=dtype)
|
||||
res2 = shared_expert(
|
||||
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
hidden_states2,
|
||||
w1_q,
|
||||
w2_q,
|
||||
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
|
||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
|
||||
w1 = convert_weight_packed(w1) # [2N, K]
|
||||
w2 = convert_weight_packed(w2) # [K, N]
|
||||
out = shared_expert(
|
||||
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
|
||||
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
|
||||
out = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||
a2,
|
||||
w1,
|
||||
w2,
|
||||
|
||||
Reference in New Issue
Block a user