[CPU] Fix build issue (#6419)
This commit is contained in:
@@ -5,9 +5,7 @@ set(CMAKE_CXX_STANDARD 17)
|
|||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||||
|
|
||||||
# Torch
|
find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
|
||||||
find_package(Torch REQUIRED)
|
|
||||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
|
||||||
|
|
||||||
execute_process(
|
execute_process(
|
||||||
COMMAND ${Python_EXECUTABLE}
|
COMMAND ${Python_EXECUTABLE}
|
||||||
@@ -23,8 +21,9 @@ find_package(Torch REQUIRED)
|
|||||||
include_directories(
|
include_directories(
|
||||||
${TORCH_INCLUDE_DIRS}
|
${TORCH_INCLUDE_DIRS}
|
||||||
${TORCH_INSTALL_PREFIX}/include
|
${TORCH_INSTALL_PREFIX}/include
|
||||||
${Python3_INCLUDE_DIRS}
|
${Python_INCLUDE_DIRS}
|
||||||
${CMAKE_SOURCE_DIR}/csrc
|
${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../../include
|
||||||
)
|
)
|
||||||
|
|
||||||
# Platform-specific library directory
|
# Platform-specific library directory
|
||||||
@@ -39,23 +38,7 @@ else()
|
|||||||
endif()
|
endif()
|
||||||
link_directories(${PLAT_LIB_DIR})
|
link_directories(${PLAT_LIB_DIR})
|
||||||
|
|
||||||
set(SOURCES
|
file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/activation.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/bmm.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/decode.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/extend.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/gemm_int8.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/moe.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/moe_int8.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/norm.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/qkv_proj.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/topk.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/interface.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/shm.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/torch_extension_cpu.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
add_compile_options(
|
add_compile_options(
|
||||||
-O3
|
-O3
|
||||||
@@ -64,24 +47,10 @@ add_compile_options(
|
|||||||
-fopenmp
|
-fopenmp
|
||||||
)
|
)
|
||||||
|
|
||||||
add_library(sgl_kernel_common_ops SHARED ${SOURCES})
|
Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||||
|
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES})
|
||||||
|
target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
|
||||||
|
|
||||||
target_link_libraries(sgl_kernel_common_ops
|
install(TARGETS common_ops
|
||||||
PRIVATE
|
LIBRARY DESTINATION sgl_kernel
|
||||||
${TORCH_LIBRARIES}
|
|
||||||
${Python3_LIBRARIES}
|
|
||||||
c10
|
|
||||||
)
|
|
||||||
|
|
||||||
set_target_properties(sgl_kernel_common_ops PROPERTIES
|
|
||||||
INSTALL_RPATH "$ORIGIN/../../torch/lib"
|
|
||||||
PREFIX ""
|
|
||||||
OUTPUT_NAME "sgl_kernel.common_ops"
|
|
||||||
)
|
|
||||||
|
|
||||||
target_compile_definitions(sgl_kernel_common_ops PRIVATE TORCH_API_INCLUDE_EXTENSION_H)
|
|
||||||
|
|
||||||
# Install
|
|
||||||
install(TARGETS sgl_kernel_common_ops
|
|
||||||
LIBRARY DESTINATION ${Python3_SITEARCH}
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -74,7 +74,8 @@ void bmm_kernel_impl(
|
|||||||
// out : [B, M, N]
|
// out : [B, M, N]
|
||||||
// scale: [] 0-dim tensor for per tensor quant
|
// scale: [] 0-dim tensor for per tensor quant
|
||||||
//
|
//
|
||||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
|
void bmm_cpu(
|
||||||
|
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
|
||||||
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
||||||
|
|
||||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||||
|
|||||||
@@ -463,7 +463,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
|
|||||||
// bias : [N]
|
// bias : [N]
|
||||||
// out : [M, N]
|
// out : [M, N]
|
||||||
//
|
//
|
||||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
|
at::Tensor
|
||||||
|
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||||
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||||
|
|
||||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||||
|
|||||||
@@ -482,7 +482,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
|||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::vector<int64_t> block_size,
|
std::vector<int64_t> block_size,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||||
|
|||||||
@@ -366,7 +366,7 @@ at::Tensor int8_scaled_mm_cpu(
|
|||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales1,
|
at::Tensor& scales1,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||||
@@ -424,7 +424,7 @@ at::Tensor int8_scaled_mm_with_quant(
|
|||||||
at::Tensor& mat1,
|
at::Tensor& mat1,
|
||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ static bool is_initialized = false;
|
|||||||
|
|
||||||
static bool all_ranks_local_p = false;
|
static bool all_ranks_local_p = false;
|
||||||
|
|
||||||
void initialize(int size, int rank) {
|
void initialize(int64_t size, int64_t rank) {
|
||||||
if (is_initialized) {
|
if (is_initialized) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -47,12 +47,11 @@ void initialize(int size, int rank) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
|
void shm_allreduce(
|
||||||
|
torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op) {
|
||||||
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
|
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
|
||||||
|
|
||||||
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
|
TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
|
||||||
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
|
|
||||||
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
|
|
||||||
|
|
||||||
auto numel = data.numel();
|
auto numel = data.numel();
|
||||||
|
|
||||||
@@ -81,7 +80,7 @@ void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> p
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
|
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim) {
|
||||||
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
|
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
|
||||||
|
|
||||||
auto numel = data.numel();
|
auto numel = data.numel();
|
||||||
|
|||||||
@@ -946,10 +946,10 @@ at::Tensor fused_experts_cpu(
|
|||||||
at::Tensor& topk_ids,
|
at::Tensor& topk_ids,
|
||||||
bool inplace,
|
bool inplace,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
std::optional<at::Tensor>& w1_scale,
|
const std::optional<at::Tensor>& w1_scale,
|
||||||
std::optional<at::Tensor>& w2_scale,
|
const std::optional<at::Tensor>& w2_scale,
|
||||||
std::optional<at::Tensor>& a1_scale,
|
const std::optional<at::Tensor>& a1_scale,
|
||||||
std::optional<at::Tensor>& a2_scale,
|
const std::optional<at::Tensor>& a2_scale,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION(
|
RECORD_FUNCTION(
|
||||||
"sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
|
"sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
|
||||||
@@ -1138,11 +1138,11 @@ at::Tensor shared_expert_cpu(
|
|||||||
bool inplace,
|
bool inplace,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
bool use_fp8_w8a16,
|
bool use_fp8_w8a16,
|
||||||
std::optional<at::Tensor>& w1_scale,
|
const std::optional<at::Tensor>& w1_scale,
|
||||||
std::optional<at::Tensor>& w2_scale,
|
const std::optional<at::Tensor>& w2_scale,
|
||||||
std::optional<std::vector<int64_t>> block_size,
|
const std::optional<std::vector<int64_t>> block_size,
|
||||||
std::optional<at::Tensor>& a1_scale,
|
const std::optional<at::Tensor>& a1_scale,
|
||||||
std::optional<at::Tensor>& a2_scale,
|
const std::optional<at::Tensor>& a2_scale,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
|
RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
|
||||||
|
|
||||||
|
|||||||
@@ -308,18 +308,18 @@ void rotary_emb_kernel_impl(
|
|||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
extern at::Tensor
|
extern at::Tensor
|
||||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
|
||||||
|
|
||||||
extern at::Tensor int8_scaled_mm_with_quant(
|
extern at::Tensor int8_scaled_mm_with_quant(
|
||||||
at::Tensor& mat1,
|
at::Tensor& mat1,
|
||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
extern void
|
extern void
|
||||||
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
||||||
|
|
||||||
// NB: shapes in DeepSeek R1
|
// NB: shapes in DeepSeek R1
|
||||||
//
|
//
|
||||||
@@ -343,9 +343,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
at::Tensor& cos_sin_cache,
|
at::Tensor& cos_sin_cache,
|
||||||
double eps,
|
double eps,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
std::optional<at::Tensor>& q_a_proj_scale,
|
std::optional<at::Tensor> q_a_proj_scale,
|
||||||
std::optional<at::Tensor>& q_b_proj_scale,
|
std::optional<at::Tensor> q_b_proj_scale,
|
||||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
std::optional<at::Tensor> kv_a_proj_scale,
|
||||||
bool is_vnni) {
|
bool is_vnni) {
|
||||||
RECORD_FUNCTION(
|
RECORD_FUNCTION(
|
||||||
"sgl-kernel::qkv_proj_with_rope",
|
"sgl-kernel::qkv_proj_with_rope",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#include <torch/torch.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include <torch/library.h>
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
#include "sgl_kernel_ops.h"
|
||||||
#include "shm.h"
|
#include "shm.h"
|
||||||
|
|
||||||
// silu_and_mul
|
// silu_and_mul
|
||||||
@@ -85,7 +86,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight);
|
|||||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
|
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
|
||||||
|
|
||||||
// gemm
|
// gemm
|
||||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
at::Tensor
|
||||||
|
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
|
||||||
|
|
||||||
// igemm
|
// igemm
|
||||||
at::Tensor int8_scaled_mm_cpu(
|
at::Tensor int8_scaled_mm_cpu(
|
||||||
@@ -93,7 +95,7 @@ at::Tensor int8_scaled_mm_cpu(
|
|||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales1,
|
at::Tensor& scales1,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
@@ -103,7 +105,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
|||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::vector<int64_t> block_size,
|
std::vector<int64_t> block_size,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
@@ -112,12 +114,12 @@ at::Tensor int8_scaled_mm_with_quant(
|
|||||||
at::Tensor& mat1,
|
at::Tensor& mat1,
|
||||||
at::Tensor& mat2,
|
at::Tensor& mat2,
|
||||||
at::Tensor& scales2,
|
at::Tensor& scales2,
|
||||||
std::optional<at::Tensor>& bias,
|
const std::optional<at::Tensor>& bias,
|
||||||
at::ScalarType out_dtype,
|
at::ScalarType out_dtype,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
// bmm
|
// bmm
|
||||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
|
||||||
|
|
||||||
// fused moe
|
// fused moe
|
||||||
at::Tensor fused_experts_cpu(
|
at::Tensor fused_experts_cpu(
|
||||||
@@ -128,10 +130,10 @@ at::Tensor fused_experts_cpu(
|
|||||||
at::Tensor& topk_ids,
|
at::Tensor& topk_ids,
|
||||||
bool inplace,
|
bool inplace,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
std::optional<at::Tensor>& w1_scale,
|
const std::optional<at::Tensor>& w1_scale,
|
||||||
std::optional<at::Tensor>& w2_scale,
|
const std::optional<at::Tensor>& w2_scale,
|
||||||
std::optional<at::Tensor>& a1_scale,
|
const std::optional<at::Tensor>& a1_scale,
|
||||||
std::optional<at::Tensor>& a2_scale,
|
const std::optional<at::Tensor>& a2_scale,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
at::Tensor shared_expert_cpu(
|
at::Tensor shared_expert_cpu(
|
||||||
@@ -143,11 +145,11 @@ at::Tensor shared_expert_cpu(
|
|||||||
bool inplace,
|
bool inplace,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
bool use_fp8_w8a16,
|
bool use_fp8_w8a16,
|
||||||
std::optional<at::Tensor>& w1_scale,
|
const std::optional<at::Tensor>& w1_scale,
|
||||||
std::optional<at::Tensor>& w2_scale,
|
const std::optional<at::Tensor>& w2_scale,
|
||||||
std::optional<std::vector<int64_t>> block_size,
|
const std::optional<std::vector<int64_t>> block_size,
|
||||||
std::optional<at::Tensor>& a1_scale,
|
const std::optional<at::Tensor>& a1_scale,
|
||||||
std::optional<at::Tensor>& a2_scale,
|
const std::optional<at::Tensor>& a2_scale,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
// weight absorption
|
// weight absorption
|
||||||
@@ -163,80 +165,130 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
|||||||
at::Tensor& cos_sin_cache,
|
at::Tensor& cos_sin_cache,
|
||||||
double eps,
|
double eps,
|
||||||
bool use_int8_w8a8,
|
bool use_int8_w8a8,
|
||||||
std::optional<at::Tensor>& q_a_proj_scale,
|
std::optional<at::Tensor> q_a_proj_scale,
|
||||||
std::optional<at::Tensor>& q_b_proj_scale,
|
std::optional<at::Tensor> q_b_proj_scale,
|
||||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
std::optional<at::Tensor> kv_a_proj_scale,
|
||||||
bool is_vnni);
|
bool is_vnni);
|
||||||
|
|
||||||
// shared memory init
|
// shared memory init
|
||||||
void initialize(int size, int rank);
|
void initialize(int64_t size, int64_t rank);
|
||||||
|
|
||||||
// shared memory all_reduce
|
// shared memory all_reduce
|
||||||
void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);
|
void shm_allreduce(
|
||||||
|
at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, c10::intrusive_ptr<c10d::ReduceOp> op);
|
||||||
|
|
||||||
// shared memory all_gather
|
// shared memory all_gather
|
||||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
|
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);
|
||||||
|
|
||||||
// rope
|
// rope
|
||||||
std::tuple<at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor>
|
||||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||||
// activation
|
// activation
|
||||||
m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
|
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
|
||||||
|
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
|
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
|
||||||
m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
|
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
|
||||||
|
m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
|
||||||
|
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
|
||||||
|
|
||||||
// topk
|
// topk
|
||||||
m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
|
m.def(
|
||||||
|
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
|
||||||
|
"int topk_group) -> (Tensor, Tensor)");
|
||||||
|
m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
|
||||||
|
|
||||||
// biased group topk
|
// biased group topk
|
||||||
m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
|
m.def(
|
||||||
|
"biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
|
||||||
|
"renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
|
||||||
|
m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
|
||||||
|
|
||||||
// decode
|
// decode
|
||||||
m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
|
m.def(
|
||||||
|
"decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
|
||||||
|
"req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
|
||||||
|
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
|
||||||
|
|
||||||
// extend
|
// extend
|
||||||
m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
|
m.def(
|
||||||
|
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
|
||||||
|
"Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
|
||||||
|
"extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
|
||||||
|
m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
|
||||||
|
|
||||||
// weight prepack
|
// weight prepack
|
||||||
m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
|
m.def("convert_weight_packed(Tensor weight) -> Tensor");
|
||||||
|
m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||||
|
|
||||||
// quant
|
// quant
|
||||||
m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
|
m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
|
||||||
|
m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
|
||||||
|
|
||||||
// gemm
|
// gemm
|
||||||
m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
|
m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
|
||||||
|
m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||||
|
|
||||||
// igemm
|
// igemm
|
||||||
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
|
m.def(
|
||||||
|
"int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
|
||||||
|
"out_dtype, bool is_vnni) -> Tensor");
|
||||||
|
m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
|
||||||
|
|
||||||
// fp8 gemm
|
// fp8 gemm
|
||||||
m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
|
m.def(
|
||||||
|
"fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
|
||||||
|
"out_dtype, bool is_vnni) -> Tensor");
|
||||||
|
m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
|
||||||
|
|
||||||
// quant + igemm
|
// quant + igemm
|
||||||
m.def(
|
m.def(
|
||||||
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
|
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
|
||||||
|
"is_vnni) -> Tensor");
|
||||||
|
m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
|
||||||
|
|
||||||
// bmm
|
// bmm
|
||||||
m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
|
m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
|
||||||
|
m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
|
||||||
|
|
||||||
// moe
|
// moe
|
||||||
m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
|
m.def(
|
||||||
|
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
|
||||||
|
"inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
|
||||||
|
"is_vnni) -> Tensor");
|
||||||
|
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||||
|
|
||||||
// weight absorption
|
// weight absorption
|
||||||
m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
|
m.def(
|
||||||
|
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
|
||||||
|
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
|
||||||
|
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
|
||||||
|
"kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
|
||||||
|
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
|
||||||
|
|
||||||
// shared expert
|
// shared expert
|
||||||
m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
|
m.def(
|
||||||
|
"shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
|
||||||
|
"routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
|
||||||
|
"w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
|
||||||
|
m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
|
||||||
|
|
||||||
// all reduce
|
// all reduce
|
||||||
m.def("initialize", &initialize, "shared memory initialization for CPU");
|
m.def("initialize(int size, int rank) -> ()");
|
||||||
m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
|
m.impl("initialize", torch::kCPU, &initialize);
|
||||||
m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
|
m.def(
|
||||||
|
"shm_allreduce(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, "
|
||||||
|
"__torch__.torch.classes.c10d.ReduceOp reduce_op) -> ()");
|
||||||
|
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
|
||||||
|
m.def("shm_allgather(Tensor data, __torch__.torch.classes.c10d.ProcessGroup process_group, int dim) -> Tensor");
|
||||||
|
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
|
||||||
|
|
||||||
// rope
|
// rope
|
||||||
m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
|
m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
|
||||||
|
m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
REGISTER_EXTENSION(common_ops)
|
||||||
|
|||||||
@@ -34,7 +34,3 @@ exclude = [
|
|||||||
cmake.source-dir = "csrc/cpu"
|
cmake.source-dir = "csrc/cpu"
|
||||||
cmake.build-type = "Release"
|
cmake.build-type = "Release"
|
||||||
minimum-version = "build-system.requires"
|
minimum-version = "build-system.requires"
|
||||||
|
|
||||||
wheel.py-api = "cp39"
|
|
||||||
wheel.license-files = []
|
|
||||||
wheel.packages = ["python/sgl_kernel"]
|
|
||||||
|
|||||||
@@ -50,7 +50,9 @@ def _get_version():
|
|||||||
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
|
cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"
|
||||||
|
|
||||||
operator_namespace = "sgl_kernel"
|
operator_namespace = "sgl_kernel"
|
||||||
include_dirs = []
|
include_dirs = [
|
||||||
|
"../../include",
|
||||||
|
]
|
||||||
|
|
||||||
sources = [
|
sources = [
|
||||||
"csrc/cpu/activation.cpp",
|
"csrc/cpu/activation.cpp",
|
||||||
@@ -99,7 +101,7 @@ ext_modules = [
|
|||||||
extra_compile_args=extra_compile_args,
|
extra_compile_args=extra_compile_args,
|
||||||
libraries=libraries,
|
libraries=libraries,
|
||||||
extra_link_args=extra_link_args,
|
extra_link_args=extra_link_args,
|
||||||
py_limited_api=True,
|
py_limited_api=False,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,10 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
# TODO: use interface in cpu.py
|
||||||
|
import sgl_kernel
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# TODO: use interface in cpu.py
|
|
||||||
from sgl_kernel.common_ops import (
|
|
||||||
convert_weight_packed,
|
|
||||||
fp8_scaled_mm_cpu,
|
|
||||||
int8_scaled_mm_cpu,
|
|
||||||
int8_scaled_mm_with_quant,
|
|
||||||
per_token_quant_int8_cpu,
|
|
||||||
weight_packed_linear,
|
|
||||||
)
|
|
||||||
from utils import (
|
from utils import (
|
||||||
convert_weight,
|
convert_weight,
|
||||||
native_w8a8_per_token_matmul,
|
native_w8a8_per_token_matmul,
|
||||||
@@ -58,10 +50,14 @@ class TestGemm(CustomTestCase):
|
|||||||
|
|
||||||
ref = ref.bfloat16()
|
ref = ref.bfloat16()
|
||||||
|
|
||||||
out = weight_packed_linear(mat1, mat2, bias if has_bias else None, False)
|
out = torch.ops.sgl_kernel.weight_packed_linear(
|
||||||
|
mat1, mat2, bias if has_bias else None, False
|
||||||
|
)
|
||||||
|
|
||||||
packed_mat2 = convert_weight_packed(mat2)
|
packed_mat2 = torch.ops.sgl_kernel.convert_weight_packed(mat2)
|
||||||
out2 = weight_packed_linear(mat1, packed_mat2, bias if has_bias else None, True)
|
out2 = torch.ops.sgl_kernel.weight_packed_linear(
|
||||||
|
mat1, packed_mat2, bias if has_bias else None, True
|
||||||
|
)
|
||||||
|
|
||||||
atol = rtol = precision[ref.dtype]
|
atol = rtol = precision[ref.dtype]
|
||||||
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(ref, out, atol=atol, rtol=rtol))
|
||||||
@@ -100,14 +96,14 @@ class TestGemm(CustomTestCase):
|
|||||||
|
|
||||||
atol = rtol = precision[ref_out.dtype]
|
atol = rtol = precision[ref_out.dtype]
|
||||||
|
|
||||||
Aq2, As2 = per_token_quant_int8_cpu(A)
|
Aq2, As2 = torch.ops.sgl_kernel.per_token_quant_int8_cpu(A)
|
||||||
out = int8_scaled_mm_cpu(
|
out = torch.ops.sgl_kernel.int8_scaled_mm_cpu(
|
||||||
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
|
Aq2, Bq, As2, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
# test the fused version
|
# test the fused version
|
||||||
fused_out = int8_scaled_mm_with_quant(
|
fused_out = torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
||||||
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
|
A, Bq, Bs, bias if has_bias else None, torch.bfloat16, False
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
|
self.assertTrue(torch.allclose(ref_out, fused_out, atol=atol, rtol=rtol))
|
||||||
@@ -157,9 +153,9 @@ class TestGemm(CustomTestCase):
|
|||||||
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
|
ref = torch.matmul(data.to(A_dtype), dq_weight.T)
|
||||||
|
|
||||||
if prepack:
|
if prepack:
|
||||||
fp8_weight = convert_weight_packed(fp8_weight)
|
fp8_weight = torch.ops.sgl_kernel.convert_weight_packed(fp8_weight)
|
||||||
|
|
||||||
opt = fp8_scaled_mm_cpu(
|
opt = torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
|
||||||
data,
|
data,
|
||||||
fp8_weight,
|
fp8_weight,
|
||||||
scales,
|
scales,
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ import itertools
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
# TODO: use interface in cpu.py
|
||||||
|
import sgl_kernel
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# TODO: use interface in cpu.py
|
|
||||||
from sgl_kernel.common_ops import convert_weight_packed
|
|
||||||
from sgl_kernel.common_ops import shared_expert_cpu as shared_expert
|
|
||||||
from utils import (
|
from utils import (
|
||||||
BLOCK_K,
|
BLOCK_K,
|
||||||
BLOCK_N,
|
BLOCK_N,
|
||||||
@@ -55,7 +53,7 @@ class TestSharedExpert(CustomTestCase):
|
|||||||
fused_output.float(),
|
fused_output.float(),
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
).to(dtype=dtype)
|
).to(dtype=dtype)
|
||||||
res = shared_expert(
|
res = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
@@ -113,7 +111,7 @@ class TestSharedExpert(CustomTestCase):
|
|||||||
fused_output.float(),
|
fused_output.float(),
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
).to(dtype=dtype)
|
).to(dtype=dtype)
|
||||||
res2 = shared_expert(
|
res2 = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||||
hidden_states2,
|
hidden_states2,
|
||||||
w1_q,
|
w1_q,
|
||||||
w2_q,
|
w2_q,
|
||||||
@@ -181,9 +179,9 @@ class TestSharedExpert(CustomTestCase):
|
|||||||
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
ref_out = shared_out + fused_out.float() * routed_scaling_factor
|
||||||
ref_out = ref_out.to(dtype=dtype)
|
ref_out = ref_out.to(dtype=dtype)
|
||||||
|
|
||||||
w1 = convert_weight_packed(w1) # [2N, K]
|
w1 = torch.ops.sgl_kernel.convert_weight_packed(w1) # [2N, K]
|
||||||
w2 = convert_weight_packed(w2) # [K, N]
|
w2 = torch.ops.sgl_kernel.convert_weight_packed(w2) # [K, N]
|
||||||
out = shared_expert(
|
out = torch.ops.sgl_kernel.shared_expert_cpu(
|
||||||
a2,
|
a2,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
|
|||||||
Reference in New Issue
Block a user