Reorganize c++ source files in sgl-kernel with multiple folders (#4025)
This commit is contained in:
@@ -80,6 +80,12 @@ nvcc_flags = [
|
||||
"-std=c++17",
|
||||
"-use_fast_math",
|
||||
"-DFLASHINFER_ENABLE_F16",
|
||||
"-DCUTLASS_VERSIONS_GENERATED",
|
||||
"-DCUTE_USE_PACKED_TUPLE=1",
|
||||
"-DCUTLASS_TEST_LEVEL=0",
|
||||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
|
||||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
|
||||
"--ptxas-options=-v",
|
||||
"-Xcompiler=-Wconversion",
|
||||
"-Xcompiler=-fno-strict-aliasing",
|
||||
]
|
||||
@@ -91,18 +97,18 @@ nvcc_flags_fp8 = [
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension.cc",
|
||||
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
|
||||
"src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
|
||||
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
|
||||
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/eagle_utils.cu",
|
||||
"src/sgl-kernel/csrc/speculative_sampling.cu",
|
||||
"src/sgl-kernel/csrc/per_token_group_quant_fp8.cu",
|
||||
"src/sgl-kernel/csrc/cublas_grouped_gemm.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/norm.cu",
|
||||
|
||||
@@ -43,8 +43,8 @@ include_dirs = [
|
||||
|
||||
sources = [
|
||||
"src/sgl-kernel/torch_extension_rocm.cc",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/custom_all_reduce.hip",
|
||||
"src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
|
||||
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
|
||||
]
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <speculative_sampling.cuh>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
#include "speculative_sampling.cuh"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
@@ -35,7 +35,24 @@ limitations under the License.
|
||||
}
|
||||
|
||||
using fptr_t = int64_t;
|
||||
|
||||
/*
|
||||
* From csrc/activation
|
||||
*/
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
|
||||
int64_t cuda_stream);
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
// ROCM custom allreduce
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector<std::string>& handles,
|
||||
const std::vector<int64_t>& offsets, int64_t rank, bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
|
||||
@@ -50,7 +67,7 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
||||
torch::Tensor allocate_meta_buffer(int64_t size);
|
||||
torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp);
|
||||
#else
|
||||
// trt_reduce
|
||||
// TRTLLM custom allreduce
|
||||
fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector<fptr_t>& buffers,
|
||||
const std::vector<fptr_t>& tmp_result_buffers, const std::vector<fptr_t>& barrier_in,
|
||||
const std::vector<fptr_t>& barrier_out);
|
||||
@@ -61,94 +78,34 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::vector<int64_t>>&
|
||||
const std::vector<std::vector<int64_t>>& offsets);
|
||||
#endif
|
||||
|
||||
// moe_align_block_size
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype);
|
||||
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
|
||||
double eps, double fp8_min, double fp8_max);
|
||||
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
|
||||
const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
|
||||
int64_t cublas_handle, int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
|
||||
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
||||
|
||||
// int8_scaled_mm
|
||||
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
|
||||
// fp8_scaled_mm
|
||||
torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
|
||||
// fp8_blockwise_scaled_mm
|
||||
torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Dtype& out_dtype);
|
||||
|
||||
// lightning_attention_decode
|
||||
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
// rms norm
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// fused rms norm
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
|
||||
|
||||
// gemma rms norm
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// fused gemma rms norm
|
||||
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// silu and mul
|
||||
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
// gelu tanh and mul
|
||||
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
// gelu and mul
|
||||
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
|
||||
|
||||
// bmm fp8
|
||||
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
|
||||
at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
|
||||
|
||||
// min p sampling from probs
|
||||
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
|
||||
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// top k renorm probs
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val, int64_t cuda_stream);
|
||||
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
// wrapper for binding
|
||||
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
|
||||
int64_t cuda_stream) {
|
||||
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||
}
|
||||
|
||||
// top p renorm probs
|
||||
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val, int64_t cuda_stream);
|
||||
|
||||
// top k top p sampling from probs
|
||||
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
|
||||
at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
|
||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// top p sampling from probs
|
||||
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
|
||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
|
||||
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
|
||||
at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
|
||||
int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates, at::Tensor retrive_index,
|
||||
@@ -165,11 +122,40 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te
|
||||
at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
|
||||
int64_t depth, int64_t draft_token_num);
|
||||
|
||||
// sgl_per_token_group_quant_fp8
|
||||
void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size,
|
||||
double eps, double fp8_min, double fp8_max);
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
|
||||
at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
|
||||
void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
|
||||
std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
// top k renorm probs
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
|
||||
unsigned int top_k_val, int64_t cuda_stream);
|
||||
// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension.
|
||||
inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
|
||||
std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
|
||||
int64_t cuda_stream) {
|
||||
top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
|
||||
}
|
||||
void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
|
||||
double top_p_val, int64_t cuda_stream);
|
||||
void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
|
||||
at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
|
||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
|
||||
std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
|
||||
int64_t cuda_stream);
|
||||
void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
|
||||
at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
|
||||
int64_t cuda_stream);
|
||||
|
||||
// cublas grouped gemm
|
||||
void cublas_grouped_gemm(const std::vector<torch::Tensor>& inputs, const std::vector<torch::Tensor>& weights,
|
||||
const std::vector<torch::Tensor>& outputs, const torch::Dtype& out_dtype,
|
||||
int64_t cublas_handle, int64_t cuda_stream);
|
||||
/*
|
||||
* Other
|
||||
*/
|
||||
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
|
||||
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
|
||||
@@ -19,7 +19,33 @@ limitations under the License.
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
// trt_reduce
|
||||
/*
|
||||
* From csrc/activation
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def(
|
||||
"init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
|
||||
"barrier_in, int[] barrier_out) -> int");
|
||||
@@ -36,108 +62,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||
m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers);
|
||||
|
||||
// moe_align_block_size
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// int8_scaled_mm
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
m.def(
|
||||
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
|
||||
|
||||
// fp8_scaled_mm
|
||||
m.def(
|
||||
"fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
"bias) -> Tensor");
|
||||
m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
|
||||
|
||||
// fp8_blockwise_scaled_mm
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
|
||||
"Tensor");
|
||||
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
|
||||
|
||||
// lightning_attention_decode
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
|
||||
m.def(
|
||||
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
|
||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
m.def(
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
|
||||
// rms norm
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
// fused rms norm
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
// gemma rms norm
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
// fused gemma rms norm
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
// silu and mul
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
// gelu tanh and mul
|
||||
m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
|
||||
|
||||
// gelu and mul
|
||||
m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
|
||||
|
||||
// bmm fp8
|
||||
m.def(
|
||||
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||
"cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
|
||||
// min p sampling from probs
|
||||
m.def(
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
// top k renorm probs
|
||||
m.def(
|
||||
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||
|
||||
// top p renorm probs
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
// top k top p sampling from probs
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
// top p sampling from probs
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
// apply rope with cos sin cache
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
// tree spec decode
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
m.def(
|
||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
@@ -145,7 +112,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||
|
||||
// eagle build tree
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
|
||||
@@ -153,24 +119,55 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
// eagle build tree
|
||||
m.def(
|
||||
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
|
||||
"int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
|
||||
|
||||
// per_token_group_quant_fp8
|
||||
/*
|
||||
* From FlashInfer
|
||||
*/
|
||||
m.def(
|
||||
"sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
|
||||
" float eps, float fp8_min, float fp8_max) -> ()");
|
||||
m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
|
||||
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
|
||||
"cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
|
||||
// cublas grouped gemm
|
||||
m.def(
|
||||
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
|
||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
||||
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper);
|
||||
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
|
||||
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
|
||||
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
|
||||
|
||||
/*
|
||||
* Other
|
||||
*/
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_kernels)
|
||||
|
||||
@@ -19,7 +19,9 @@ limitations under the License.
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
// Custom all-reduce kernels
|
||||
/*
|
||||
* From csrc/allreduce
|
||||
*/
|
||||
m.def(
|
||||
"init_custom_ar(Tensor meta, Tensor rank_data, "
|
||||
"str[] handles, int[] offsets, int rank, "
|
||||
@@ -45,12 +47,16 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", ®ister_graph_buffers);
|
||||
|
||||
m.def("allocate_meta_buffer", &allocate_meta_buffer);
|
||||
m.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer);
|
||||
|
||||
m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle);
|
||||
m.impl("get_meta_buffer_ipc_handle", torch::kCPU, &get_meta_buffer_ipc_handle);
|
||||
|
||||
// moe_align_block_size
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
m.def(
|
||||
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import socket
|
||||
import time
|
||||
import unittest
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@@ -115,7 +114,7 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
|
||||
self.rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
|
||||
)
|
||||
|
||||
self.custom_ptr = custom_ops.init_custom_reduce(
|
||||
@@ -148,7 +147,7 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
self.vllm_max_size, group=group
|
||||
)
|
||||
self.vllm_rank_data = torch.empty(
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
|
||||
8 * 1024 * 1024, dtype=torch.uint8, device=torch.device("cuda:0")
|
||||
)
|
||||
self.vllm_ptr = vllm_ops.init_custom_ar(
|
||||
self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
|
||||
@@ -171,8 +170,7 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
|
||||
@staticmethod
|
||||
def init_distributed_env(world_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
ranks = [i for i in range(world_size)]
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
@@ -234,8 +232,8 @@ class TestCustomAllReduce(unittest.TestCase):
|
||||
if rank == 0:
|
||||
logger.warning(
|
||||
f"test_size = {sz}, world_size = {world_size}, "
|
||||
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms,"
|
||||
f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms"
|
||||
f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms, "
|
||||
f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms "
|
||||
)
|
||||
|
||||
self.free_custom_allreduce(group)
|
||||
Reference in New Issue
Block a user