From d052f4c8a9fb7e135ca0f0b09f6feead93db9e01 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 7 Mar 2025 20:21:08 -0800 Subject: [PATCH] New clang format for sgl kernel (#4194) --- python/upload_pypi.sh | 6 - sgl-kernel/.clang-format | 7 + .../activation/fused_add_rms_norm_kernel.cu | 13 +- .../csrc/allreduce/custom_all_reduce_hip.cuh | 116 ++-- .../csrc/allreduce/trt_reduce_internal.cu | 43 +- .../csrc/allreduce/trt_reduce_kernel.cu | 30 +- .../lightning_attention_decode_kernel.cu | 46 +- .../epilogue/epilogue_per_row_per_col_scale.h | 63 +- .../gemm/dispatch_policy.hpp | 20 +- .../gemm/gemm_universal_base_compat.h | 22 +- .../gemm/gemm_with_epilogue_visitor.h | 72 ++- .../csrc/gemm/cublas_grouped_gemm.cu | 64 +- .../csrc/gemm/fp8_blockwise_gemm_kernel.cu | 75 ++- .../sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu | 594 ++++++++++++------ .../sgl-kernel/csrc/gemm/int8_gemm_kernel.cu | 404 ++++++++---- .../csrc/gemm/per_tensor_quant_fp8.cu | 17 +- .../csrc/gemm/per_token_group_quant_fp8.cu | 33 +- .../csrc/gemm/per_token_quant_fp8.cu | 16 +- .../sgl-kernel/csrc/moe/moe_align_kernel.cu | 53 +- .../csrc/speculative/eagle_utils.cu | 86 ++- .../csrc/speculative/speculative_sampling.cu | 46 +- .../csrc/speculative/speculative_sampling.cuh | 95 ++- .../src/sgl-kernel/include/sgl_kernels_ops.h | 242 +++++-- .../include/trt_reduce_internal.cuh | 4 +- sgl-kernel/src/sgl-kernel/include/utils.h | 1 - 25 files changed, 1486 insertions(+), 682 deletions(-) delete mode 100644 python/upload_pypi.sh diff --git a/python/upload_pypi.sh b/python/upload_pypi.sh deleted file mode 100644 index 35616e1da..000000000 --- a/python/upload_pypi.sh +++ /dev/null @@ -1,6 +0,0 @@ -cp ../README.md ../LICENSE . -rm -rf dist -python3 -m build -python3 -m twine upload dist/* - -rm -rf README.md LICENSE diff --git a/sgl-kernel/.clang-format b/sgl-kernel/.clang-format index 5e690c028..afbd654a7 100644 --- a/sgl-kernel/.clang-format +++ b/sgl-kernel/.clang-format @@ -6,3 +6,10 @@ DerivePointerAlignment: false PointerAlignment: Left NamespaceIndentation: None SortIncludes: true +AllowShortLoopsOnASingleLine: false +BinPackParameters: false # Prevents packing parameters in declarations +BinPackArguments: false # Prevents packing arguments in function calls +AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis +AlignOperands: Align # Aligns arguments vertically +PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument +PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name diff --git a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu index a4ae14ae5..41f4d2e70 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu @@ -41,10 +41,15 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T // support float16, bfloat16 and float32 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { cudaError_t status = norm::FusedAddRMSNorm( - static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); + static_cast(input.data_ptr()), + static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), + 
batch_size, + hidden_size, + eps, + torch_current_stream); + TORCH_CHECK( + status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh index 06173bc42..7baf5f01e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/custom_all_reduce_hip.cuh @@ -153,19 +153,20 @@ DINLINE O downcast(array_t val) { // prior memory accesses. Note: volatile writes will not be reordered against // other volatile writes. template -DINLINE void start_sync(const RankSignals& sg, +DINLINE void start_sync( + const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - int rank) { + Signal* self_sg, + int rank) { #ifdef USE_ROCM uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, - __MEMORY_SCOPE_SYSTEM); + __scoped_atomic_store_n( + &sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) < flag) @@ -193,12 +194,13 @@ DINLINE void start_sync(const RankSignals& sg, // barrier in the all reduce kernel. If it's the final synchronization barrier, // we don't need to make any visibility guarantees for prior memory accesses. template -DINLINE void end_sync(const RankSignals& sg, +DINLINE void end_sync( + const RankSignals& sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - int rank) { + Signal* self_sg, + int rank) { #ifdef USE_ROCM __syncthreads(); // eliminate the case that prior writes are not visible after signals become @@ -209,11 +211,16 @@ DINLINE void end_sync(const RankSignals& sg, if (threadIdx.x < ngpus) { // simultaneously write to the corresponding flag of all ranks. // Latency = 1 p2p write - __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, - final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM); + __scoped_atomic_store_n( + &sg.signals[threadIdx.x]->end[blockIdx.x][rank], + flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, + __MEMORY_SCOPE_SYSTEM); // wait until we got true from all ranks - while (__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], - final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, __MEMORY_SCOPE_DEVICE) < flag) + while (__scoped_atomic_load_n( + &self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? 
__ATOMIC_RELAXED : __ATOMIC_ACQUIRE, + __MEMORY_SCOPE_DEVICE) < flag) ; } __syncthreads(); @@ -251,12 +258,16 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { } template -__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage(RankData* _dp, RankSignals sg, +__global__ void __launch_bounds__(512, 1) cross_device_reduce_1stage( + RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - T* __restrict__ result, int rank, int size) { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same @@ -280,12 +291,16 @@ DINLINE P* get_tmp_buf(volatile Signal* sg) { } template -__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage(RankData* _dp, RankSignals sg, +__global__ void __launch_bounds__(512, 1) cross_device_reduce_2stage( + RankData* _dp, + RankSignals sg, #ifndef USE_ROCM - volatile + volatile #endif - Signal* self_sg, - T* __restrict__ result, int rank, int size) { + Signal* self_sg, + T* __restrict__ result, + int rank, + int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -357,8 +372,14 @@ class CustomAllreduce { * note: this class does not own any device memory. Any required buffers * are passed in from the constructor */ - CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, const hipIpcMemHandle_t* handles, - const std::vector& offsets, int rank, bool full_nvlink = true) + CustomAllreduce( + Signal* meta, + void* rank_data, + size_t rank_data_sz, + const hipIpcMemHandle_t* handles, + const std::vector& offsets, + int rank, + bool full_nvlink = true) : rank_(rank), world_size_(offsets.size()), full_nvlink_(full_nvlink), @@ -382,8 +403,8 @@ class CustomAllreduce { auto [it, new_handle] = ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { char* ipc_ptr; - CUDACHECK(hipIpcOpenMemHandle((void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), - hipIpcMemLazyEnablePeerAccess)); + CUDACHECK(hipIpcOpenMemHandle( + (void**)&ipc_ptr, *((const hipIpcMemHandle_t*)ipc_handle), hipIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; @@ -399,13 +420,14 @@ class CustomAllreduce { void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address - if (hipPointerGetAttribute(&base_ptr, + if (hipPointerGetAttribute( + &base_ptr, #ifdef USE_ROCM - HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, #else - CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, #endif - (hipDeviceptr_t)ptr) != hipSuccess) + (hipDeviceptr_t)ptr) != hipSuccess) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(hipIpcGetMemHandle((hipIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); @@ -415,8 +437,8 @@ class CustomAllreduce { void check_rank_data_capacity(size_t num = 1) { if (d_rank_data_base_ + num > d_rank_data_end_) - throw std::runtime_error("Rank data buffer is overflowed by " + - std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + throw std::runtime_error( + "Rank data buffer is overflowed by " + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } void register_buffer(const std::vector& handles, const std::vector& offsets, void* self) { @@ -443,8 +465,8 @@ class CustomAllreduce { 
// rank 1 may get the same input address for the second allreduce, but rank 2 // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. - void register_graph_buffers(const std::vector& handles, - const std::vector>& offsets) { + void + register_graph_buffers(const std::vector& handles, const std::vector>& offsets) { auto num_buffers = graph_unreg_buffers_.size(); check_rank_data_capacity(num_buffers); std::vector rank_data(num_buffers); @@ -474,11 +496,17 @@ class CustomAllreduce { * will cause contention on NVLink bus. */ template - void allreduce(hipStream_t stream, T* input, T* output, int size, + void allreduce( + hipStream_t stream, + T* input, + T* output, + int size, #ifndef USE_ROCM - int threads = 512, int block_limit = 36){ + int threads = 512, + int block_limit = 36){ #else - int threads = 512, int block_limit = 16) { + int threads = 512, + int block_limit = 16) { #endif auto d = packed_t::P::size; if (size % d != 0) @@ -487,8 +515,8 @@ class CustomAllreduce { "of " + std::to_string(d)); if (block_limit > kMaxBlocks) - throw std::runtime_error("max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + - std::to_string(block_limit)); + throw std::runtime_error( + "max supported block limit is " + std::to_string(kMaxBlocks) + ". Got " + std::to_string(block_limit)); RankData* ptrs; hipStreamCaptureStatus status; @@ -499,17 +527,17 @@ class CustomAllreduce { } else { auto it = buffers_.find(input); if (it == buffers_.end()) - throw std::runtime_error("buffer address " + std::to_string(reinterpret_cast(input)) + - " is not registered!"); + throw std::runtime_error( + "buffer address " + std::to_string(reinterpret_cast(input)) + " is not registered!"); ptrs = it->second; } size /= d; auto bytes = size * sizeof(typename packed_t::P); int blocks = ::min(block_limit, (size + threads - 1) / threads); -#define KL(ngpus, name) \ - hipLaunchKernelGGL((name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, \ - size); +#define KL(ngpus, name) \ + hipLaunchKernelGGL( \ + (name), dim3(blocks), dim3(threads), 0, stream, ptrs, sg_, self_sg_, output, rank_, size); #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu index fa9e3a2c5..f1ee5d40e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu @@ -118,8 +118,13 @@ inline __device__ int4 add128b(T& a, T& b) { return c.packed; } -__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx) { +__inline__ __device__ void multi_gpu_barrier( + uint32_t** signals, + uint32_t const flag, + size_t const local_rank, + size_t const world_size, + int const tidx, + int const bidx) { // After this function, at least one block in each GPU has reached the barrier if (tidx < world_size) { // we can think of signals having the shape [world_size, world_size] @@ -143,8 +148,14 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const } template -__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx, int const grid_size) { +__inline__ __device__ void block_barrier( + uint32_t** 
signals, + uint32_t const flag, + size_t const local_rank, + size_t const world_size, + int const tidx, + int const bidx, + int const grid_size) { if constexpr (!start) { __syncthreads(); } @@ -227,8 +238,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc } } // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size); + block_barrier( + params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Each block accumulates the values from the different GPUs on the same node. for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { @@ -341,8 +352,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc } } } - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size); + block_barrier( + params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Each block accumulates the values from the different GPUs on the same node. for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -372,8 +383,8 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc } } - block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, - bidx, grid_size); + block_barrier( + params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, grid_size); // Gather all needed elts from other intra-node ranks for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -459,8 +470,12 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar //////////////////////////////////////////////////////////////////////////////////////////////////// template -void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, - cudaStream_t stream) { +void dispatchARKernels( + AllReduceStrategyType algo, + AllReduceParams& param, + int blocks_per_grid, + int threads_per_block, + cudaStream_t stream) { switch (algo) { case AllReduceStrategyType::ONESHOT: { oneShotAllReduceKernel<<>>(param); @@ -505,8 +520,8 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategy CHECK_CUDA_SUCCESS(cudaGetLastError()); } -void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, - cudaStream_t stream) { +void trtCustomAllReduce( + AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream) { if (params.elts_total == 0) { return; } diff --git a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu index af129de52..5c8792556 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu @@ -29,9 +29,14 @@ using IPC_KEY = std::array; class AllReduceMeta { public: - AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out) { + AllReduceMeta( + 
int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { this->rank_id = (int)rank_id; this->world_size = (int)world_size; this->barrier_in = barrier_in; @@ -86,9 +91,14 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; } -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out) { +fptr_t init_custom_ar( + int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out) { auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out); return (fptr_t)m; } @@ -124,8 +134,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { char* ipc_ptr; - CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), - cudaIpcMemLazyEnablePeerAccess)); + CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; @@ -138,8 +148,8 @@ char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { // rank 1 may get the same input address for the second allreduce, but rank 2 // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. -void register_graph_buffers(fptr_t _fa, const std::vector>& handles, - const std::vector>& offsets) { +void register_graph_buffers( + fptr_t _fa, const std::vector>& handles, const std::vector>& offsets) { AllReduceMeta* m = reinterpret_cast(_fa); std::vector handle_bytes; handle_bytes.reserve(handles.size()); diff --git a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu index 02c50498e..f9d524f60 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu @@ -23,15 +23,18 @@ limitations under the License. 
#define THREADS_PER_BLOCK 128 template -__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] - const T* __restrict__ k, // [b, h, 1, d] - const T* __restrict__ v, // [b, h, 1, e] - const float* __restrict__ past_kv, // [b, h, d, e] - const float* __restrict__ slope, // [h, 1, 1] - T* __restrict__ output, // [b, h, 1, e] - float* __restrict__ new_kv, // [b, h, d, e] - const int batch_size, const int num_heads, const int qk_dim, - const int v_dim) { +__global__ void lightning_attention_decode_kernel( + const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, + const int num_heads, + const int qk_dim, + const int v_dim) { extern __shared__ char smem[]; T* __restrict__ q_shared = reinterpret_cast(smem); T* __restrict__ k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); @@ -109,9 +112,14 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q, } } -void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, - const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, - torch::Tensor new_kv) { +void lightning_attention_decode( + const torch::Tensor& q, + const torch::Tensor& k, + const torch::Tensor& v, + const torch::Tensor& past_kv, + const torch::Tensor& slope, + torch::Tensor output, + torch::Tensor new_kv) { TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); @@ -131,8 +139,16 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); lightning_attention_decode_kernel<<>>( - q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), - slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, - qk_dim, v_dim); + q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + past_kv.data_ptr(), + slope.data_ptr(), + output.data_ptr(), + new_kv.data_ptr(), + batch_size, + num_heads, + qk_dim, + v_dim); })); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index f5cd43815..9f85bee28 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -25,9 +25,15 @@ namespace cutlass { namespace epilogue { namespace threadblock { -template +template < + typename ThreadblockShape_, + int ThreadCount, + typename ScaleTileIterator_, + typename OutputTileIterator_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementwiseFunctor_, + bool UseMasking_ = false> class EpilogueVisitorPerRowPerCol { public: using ThreadblockShape = ThreadblockShape_; @@ -69,8 +75,11 @@ class EpilogueVisitorPerRowPerCol { Arguments(typename ElementwiseFunctor::Params elementwise_) : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), 
batch_stride_D(0) {} - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_, - int64_t batch_stride_D_) + Arguments( + typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) : elementwise(elementwise_), batch_stride_alpha(batch_stride_alpha_), batch_stride_C(batch_stride_C_), @@ -131,17 +140,26 @@ class EpilogueVisitorPerRowPerCol { public: CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, - typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant, - bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), - int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + bool with_bias, + bool per_token_quant, + bool per_channel_quant, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) : params_(params), shared_storage_(shared_storage), extent_(problem_size), @@ -166,8 +184,9 @@ class EpilogueVisitorPerRowPerCol { /// Helper to indicate split-K behavior CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) { ///< Total number of split-K slices + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices } /// Called to set the batch index @@ -251,8 +270,8 @@ class EpilogueVisitorPerRowPerCol { private: CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col, - AlphaScaleElementType const& scale_row) { + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) { ComputeFragment result; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ComputeFragment::kElements; ++i) { @@ -263,8 +282,8 @@ class EpilogueVisitorPerRowPerCol { } CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col, - AlphaScaleElementType const& scale_row) { + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row) { ComputeFragment result; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < ComputeFragment::kElements; ++i) 
{ diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp index 48b0ad949..f62b51ee7 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/dispatch_policy.hpp @@ -16,16 +16,20 @@ struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum : KernelT // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp // specialized dynamic schedule For FP8 kernels with Block Scaling -template , class KernelSchedule = KernelTmaWarpSpecialized, - int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, - // while zero-value `ScaleGranularityM` indicates that scaling - // granularity is `size<0>(TileShape_MNK{})` along M. - > +template < + int Stages_, + class ClusterShape_ = Shape<_1, _1, _1>, + class KernelSchedule = KernelTmaWarpSpecialized, + int ScaleGranularityM = 0 // `ScaleGranularityM` specifies scaling granularity along M, + // while zero-value `ScaleGranularityM` indicates that scaling + // granularity is `size<0>(TileShape_MNK{})` along M. + > struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8 : MainloopSm90TmaGmmaWarpSpecialized { - static_assert(cute::is_same_v>, - "KernelSchedule must be one of the warp specialized policies"); + static_assert( + cute:: + is_same_v>, + "KernelSchedule must be one of the warp specialized policies"); }; ////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 3de9ff078..b58d84318 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -159,8 +159,9 @@ class GemmUniversalBaseCompat { get_grid_shape_(grid_tiled_shape, gemm_k_size, args); dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); + CUTLASS_TRACE_HOST( + " grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); return result; } @@ -175,8 +176,8 @@ class GemmUniversalBaseCompat { CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); if (smem_size <= (48 << 10)) { - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, - GemmKernel::kThreadCount, smem_size); + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); if (result == cudaSuccess) { CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); @@ -184,12 +185,12 @@ class GemmUniversalBaseCompat { } } else { // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, - GemmKernel::kThreadCount, 0); + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); if (result != cudaSuccess) { - CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); + CUTLASS_TRACE_HOST( + " 
cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); return -1; } @@ -226,8 +227,9 @@ class GemmUniversalBaseCompat { /// Initializes GEMM state from arguments. Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); + CUTLASS_TRACE_HOST( + "GemmUniversalBaseCompat::initialize() - workspace " << workspace + << ", stream: " << (stream ? "non-null" : "null")); size_t workspace_bytes = get_workspace_size(args); diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index 11fc87250..905d11ba2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -32,10 +32,11 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function + > struct GemmWithEpilogueVisitor { public: using Mma = Mma_; @@ -119,9 +120,15 @@ struct GemmWithEpilogueVisitor { Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} /// constructs an arguments structure - Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, - typename EpilogueVisitor::Arguments epilogue_visitor_) + Arguments( + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueVisitor::Arguments epilogue_visitor_) : mode(GemmUniversalMode::kGemm), problem_size(problem_size_), batch_count(1), @@ -269,8 +276,9 @@ struct GemmWithEpilogueVisitor { isAMisaligned = problem_size.k() % kAlignmentA; } else if (platform::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } @@ -278,8 +286,9 @@ struct GemmWithEpilogueVisitor { isBMisaligned = problem_size.n() % kAlignmentB; } else if (platform::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } @@ -287,8 +296,9 @@ struct GemmWithEpilogueVisitor { isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value || - platform::is_same>::value) { + } else if ( + platform::is_same>::value || + platform::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } @@ -373,11 +383,11 @@ struct GemmWithEpilogueVisitor { int thread_idx = threadIdx.x; // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, 
thread_idx, - tb_offset_A); + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, - tb_offset_B); + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -409,8 +419,8 @@ struct GemmWithEpilogueVisitor { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // assume identity swizzle - MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); @@ -423,11 +433,25 @@ struct GemmWithEpilogueVisitor { with_bias = false; } - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), - thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col, - params.ptr_C, params.ptr_D, threadblock_offset, - blockIdx.y * params.problem_size.m()); + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_alpha_col, + params.params_C, + params.params_D, + with_bias, + true, + true, + params.ptr_alpha_row, + params.ptr_alpha_col, + params.ptr_C, + params.ptr_D, + threadblock_offset, + blockIdx.y * params.problem_size.m()); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is currently updating diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu index ec899d330..d0a80c7bf 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu @@ -21,10 +21,13 @@ #include "utils.h" -static void check_group_count(const std::vector& inputs, const std::vector& weights, - const std::vector& outputs) { - TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())), - "The group count of inputs, weights and outputs should be the same."); +static void check_group_count( + const std::vector& inputs, + const std::vector& weights, + const std::vector& outputs) { + TORCH_CHECK( + ((inputs.size() == weights.size()) && (inputs.size() == outputs.size())), + "The group count of inputs, weights and outputs should be the same."); } static void check_device_dtype(const torch::Dtype& dtype, const std::vector& tensors) { @@ -68,21 +71,26 @@ static std::vector get_tensor_ptrs(const std::vector& tens static torch::Tensor create_ptr_pointer(const std::vector& ptrs, cudaStream_t stream) { auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA); torch::Tensor gpu_ptrs = torch::empty({static_cast(ptrs.size())}, options); - TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, - stream) == CUBLAS_STATUS_SUCCESS); + TORCH_CHECK( + 
cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) == + CUBLAS_STATUS_SUCCESS); return gpu_ptrs; } // We want compute input @ weight^T in row major // This is equivalent to computing weight @ input^T in col major // Cublas only accepts matrix in column major, so this arrangement is needed -void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k) row major = (k, m) col major - const std::vector& weights, // a: (n, k) row major = (n, k)^T col major - const std::vector& outputs, // c: (m, n) row major = (n, m) col major - const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) { - TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, - "cublas grouped_gemm can" - "only be applied to float16 and bfloat16 dtype"); +void cublas_grouped_gemm( + const std::vector& inputs, // b: (m, k) row major = (k, m) col major + const std::vector& weights, // a: (n, k) row major = (n, k)^T col major + const std::vector& outputs, // c: (m, n) row major = (n, m) col major + const torch::Dtype& out_dtype, + int64_t cublas_handle, + int64_t cuda_stream) { + TORCH_CHECK( + out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, + "cublas grouped_gemm can" + "only be applied to float16 and bfloat16 dtype"); int group_count = inputs.size(); check_group_count(inputs, weights, outputs); @@ -133,16 +141,32 @@ void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k torch::Tensor d_c = create_ptr_pointer(c_array, stream); #if defined CUDA_VERSION && CUDA_VERSION >= 12050 - auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(), - n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(), - cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type, - ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type, - ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F); + auto status = cublasGemmGroupedBatchedEx( + handle, + transa_array.data(), + transb_array.data(), + m_array.data(), + n_array.data(), + k_array.data(), + alpha_array.data(), + (void**)d_a.data_ptr(), + cuda_data_type, + lda_array.data(), + (void**)d_b.data_ptr(), + cuda_data_type, + ldb_array.data(), + beta_array.data(), + (void**)d_c.data_ptr(), + cuda_data_type, + ldc_array.data(), + group_count, + group_size.data(), + CUBLAS_COMPUTE_32F); TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status)); TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization"); return; #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion()); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion()); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu index 337a5ad69..a62a5c0ce 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu @@ -35,8 +35,12 @@ using namespace cute; template -void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b) { +void launch_sm90_fp8_blockwise_scaled_mm( + torch::Tensor& out, + const 
torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b) { using ElementAccumulator = float; using ElementCompute = float; using ElementBlockScale = float; @@ -66,19 +70,43 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, - LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, EpilogueSchedule, StoreEpilogueCompute>::CollectiveOp; + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueSchedule, + StoreEpilogueCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, - TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, KernelSchedule>::CollectiveOp; - using GemmKernel = - cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; Gemm gemm_op; @@ -127,16 +155,23 @@ void launch_sm90_fp8_blockwise_scaled_mm(torch::Tensor& out, const torch::Tensor } template -void sm90_fp8_blockwise_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b) { +void sm90_fp8_blockwise_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _1, _1>; launch_sm90_fp8_blockwise_scaled_mm(out, a, b, scales_a, scales_b); } -torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Dtype& out_dtype) { +torch::Tensor fp8_blockwise_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); @@ -145,10 +180,10 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); - TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, - "mat_a must be multiple of 16 bytes for memory alignment"); - TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, 
- "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); @@ -186,6 +221,6 @@ torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::T #endif #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); + TORCH_CHECK_NOT_IMPLEMENTED( + false, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu index 36b9585f3..64731ebe4 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu @@ -53,10 +53,17 @@ limitations under the License. using namespace cute; #if defined CUDA_VERSION && CUDA_VERSION >= 12040 -template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, - typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +template < + typename ElementType, + typename OutElementType, + typename AccumElementType, + typename CtaShape, + typename WarpShape, + int Stages, + bool WithBias, + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> struct DeviceGemmFp8RowwiseSm89 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -85,56 +92,86 @@ struct DeviceGemmFp8RowwiseSm89 { // Number of epilogue stages in EVT static constexpr int EVTEpilogueStages = 1; - using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + using OutputTileThreadMap = cutlass::epilogue::threadblock:: + OutputTileThreadLayout; // Definition of EVT using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; - using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + cutlass::multiplies, + ElementComputeEpilogue, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock:: + VisitorRowBroadcast>; using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; - using ComputeAScale = - cutlass::epilogue::threadblock::VisitorCompute; - using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using ComputeAScale = cutlass::epilogue::threadblock:: + VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock:: + VisitorColBroadcast>; using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; // With bias using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; - using ComputeAScaleWithBias = - cutlass::epilogue::threadblock::VisitorCompute; + using ComputeAScaleWithBias = 
cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, + ElementC, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EpilogueAScaleWithBias = cutlass::epilogue::threadblock::Sm80EVT; using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< - OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; - using EpilogueStore = - typename cutlass::platform::conditional, - cutlass::epilogue::threadblock::Sm80EVT>::type; + OutputTileThreadMap, + ElementC, + cutlass::FloatRoundStyle::round_to_nearest, + Stride>; + using EpilogueStore = typename cutlass::platform::conditional< + WithBias, + cutlass::epilogue::threadblock::Sm80EVT, + cutlass::epilogue::threadblock::Sm80EVT>::type; using EpilogueOp = EpilogueStore; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< - ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, - cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, - ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, - ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + AlignmentA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + AlignmentB, + ElementC, + LayoutC, + AlignmentC, + ElementAccumulator, + ElementComputeEpilogue, + OperatorClass, + ArchTag, + CtaShape, + WarpShape, + InstructionShape, + EpilogueOp, + ThreadblockSwizzle, + Stages, + FP8MathOperator, + EVTEpilogueStages>::GemmKernel; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; template -typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +typename Gemm::Arguments prepare_sm89_fp8_args( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementT = typename Gemm::ElementA; using ElementOutput = typename Gemm::ElementD; using ElementComputeEpilogue = float; @@ -158,54 +195,61 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); - typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode - {m, n, k}, // Problem size - 1, // Split-k factor - {}, // Epilogue args - ptr_a, // a pointer - ptr_b, // b pointer - nullptr, // c pointer (unused) - nullptr, // d pointer (unused) - m * k, // batch stride a (unused) - n * k, // batch stride b (unused) - m * n, // batch stride c (unused) - m * n, // batch stride d (unused) - lda, // stride a - ldb, // stride b - ldc, // stride c (unused) - ldc); // stride d (unused) + typename Gemm::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) if 
constexpr (WithBias) { - args.epilogue = {{ - { - {}, // Accumulator - {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, - {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_d, {n, _1{}, _0{}}}}; + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; } else { - args.epilogue = {{ - { - {}, // Accumulator - {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, - {} // Multiplies - }, - {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, - {} // Multiplies - }, - {ptr_d, {n, _1{}, _0{}}}}; + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; } return args; } template -void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void launch_sm89_fp8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; @@ -222,109 +266,187 @@ void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const } template -void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm89_fp8_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementInput = cutlass::float_e4m3_t; using ElementOutput = OutType; using AccumElementType = float; if (bias) { - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm89< + ElementInput, + ElementOutput, + AccumElementType, + CtaShape, + WarpShape, + Stages, + true>::Gemm; return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } else { - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm89< + ElementInput, + ElementOutput, + AccumElementType, + CtaShape, + WarpShape, + Stages, + false>::Gemm; return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } } template -void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm89_fp8_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { uint32_t const m = a.size(0); uint32_t const n = out.size(1); if (m == 1) { if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 
64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 16) { // M in (1, 16] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 64) { // M in (16, 64] if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 128) { // M in (64, 128] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 256) { // M in (128, 256] if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 5>(out, a, b, scales_a, scales_b, bias); } else if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 7>(out, a, b, scales_a, scales_b, bias); } else { - 
return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, + 4>(out, a, b, scales_a, scales_b, bias); } } else if (m <= 512) { // M in (256, 512) if (n <= 16384) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 2>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 4>(out, a, b, scales_a, scales_b, bias); } } else { // M in (512, inf) if (n <= 8192) { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 3>(out, a, b, scales_a, scales_b, bias); } else { - return sm89_fp8_dispatch_bias, - cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + return sm89_fp8_dispatch_bias< + OutType, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + 2>(out, a, b, scales_a, scales_b, bias); } } } #endif #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -template +template < + typename ElementType, + typename OutElementType, + typename AccumElementType, + typename CTAShape, + typename ClusterShape, + typename MainloopScheduleType, + typename EpilogueScheduleType, + typename TileSchedulerType = void, + bool WithBias = false> struct DeviceGemmFp8RowwiseSm90 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -374,44 +496,70 @@ struct DeviceGemmFp8RowwiseSm90 { using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default // setting in the Collective Builder // Implement rowwise scaling epilogue. - using XScale = - cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; - using WScale = - cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0, + TileShape, + ElementComputeEpilogue, + ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, - cute::Stride, cute::Int<1>, cute::Int<0>>>; + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0, + TileShape, + ElementOutput, + ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementComputeEpilogue, // First stage output type. + ElementComputeEpilogue, // First stage input types. 
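+      // Net effect of the two stages here, assuming the usual Sm90EVT wiring
+      // (scale_b fused with the accumulator first, then scale_a on top):
+      // D = scale_a * scale_b * acc, with multiply_add folding in the bias
+      // when WithBias is true.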
+ cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; - using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, + ElementOutput, + ElementComputeEpilogue, // Second stage input types. + cutlass::FloatRoundStyle::round_to_nearest>; using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; // With bias - using ComputeWithBias = - cutlass::epilogue::fusion::Sm90Compute; + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, + ElementOutput, + ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; using EpilogueEVT = typename cutlass::platform::conditional::type; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, - AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementComputeEpilogue, + ElementC, + LayoutC, + AlignmentC, + ElementOutput, + LayoutOutput, + AlignmentOutput, + cutlass::epilogue::TmaWarpSpecialized, EpilogueEVT>::CollectiveOp; using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; @@ -423,22 +571,38 @@ struct DeviceGemmFp8RowwiseSm90 { using FastAccum = FastPongSchedule; // Default apply Pingpong using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, - TileShape, ClusterShape, + ArchTag, + OperatorClass, + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, MainloopScheduleType>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerType>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; template -typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +typename Gemm::Arguments prepare_sm90_fp8_args( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementT = typename Gemm::ElementA; using ElementOutput = typename Gemm::ElementD; using ElementComputeEpilogue = float; @@ -465,14 +629,15 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); - typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - 
{ptr_a, stride_a, ptr_b, stride_b}, - {{}, // epilogue.thread - nullptr, - stride_c, - ptr_d, - stride_d}}; + typename Gemm::Arguments args = { + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; if constexpr (WithBias) { args.epilogue.thread = { {ptr_scales_a}, @@ -500,9 +665,13 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: } template -void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void launch_sm90_fp8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; @@ -519,66 +688,117 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const TORCH_CHECK(status == cutlass::Status::kSuccess) } -template -void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias, bool fast_accum = true, - bool use_persistent = false) { +template < + typename OutType, + typename CTAShape, + typename ClusterShape, + typename MainloopScheduleType, + typename TileSchedulerType> +void sm90_fp8_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias, + bool fast_accum = true, + bool use_persistent = false) { using ElementInput = cutlass::float_e4m3_t; using ElementOutput = OutType; using AccumElementType = float; using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; if (bias) { - using Gemm = - typename DeviceGemmFp8RowwiseSm90::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm90< + ElementInput, + ElementOutput, + AccumElementType, + CTAShape, + ClusterShape, + MainloopScheduleType, + EpilogueScheduleType, + TileSchedulerType, + true>::Gemm; return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } else { - using Gemm = - typename DeviceGemmFp8RowwiseSm90::Gemm; + using Gemm = typename DeviceGemmFp8RowwiseSm90< + ElementInput, + ElementOutput, + AccumElementType, + CTAShape, + ClusterShape, + MainloopScheduleType, + EpilogueScheduleType, + TileSchedulerType, + false>::Gemm; return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); } } template -void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_fp8_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { uint32_t const m = a.size(0); using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; using BasicTileScheduler = void; if (m <= 1) { - return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, - BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + return 
sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _8, _1>, + FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); } if (m <= 64) { // m in [1, 64] - return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _4, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else if (m <= 256) { // m in (64, 256] - return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_64, _64, _128>, + Shape<_1, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else if (m <= 1024) { // m in (256, 1024] - return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_128, _128, _128>, + Shape<_1, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } else { // m in (1024, inf) - return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, - PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + return sm90_fp8_dispatch_bias< + OutType, + Shape<_128, _128, _128>, + Shape<_2, _1, _1>, + FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); } } #endif -torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias) { +torch::Tensor fp8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); @@ -587,10 +807,10 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); - TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, - "mat_a must be multiple of 16 bytes for memory alignment"); - TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, - "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK( + (mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu index 4a8130d66..86aa3b8f2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu +++ 
b/sgl-kernel/src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu @@ -35,11 +35,20 @@ limitations under the License. using namespace cute; -template -void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +template < + typename ElementOutput, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + int NumStages> +void cutlass_int8_scaled_mm( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { using ElementAccumulator = int32_t; using ElementCompute = float; using ElementInputA = int8_t; @@ -48,30 +57,51 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons using OperatorClass = cutlass::arch::OpClassTensorOp; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; - using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration; + using DefaultGemmConf = cutlass::gemm::device:: + DefaultGemmConfiguration; using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp; using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< - ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB, - cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor, - ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, - ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel; + ElementInputA, + cutlass::layout::RowMajor, + DefaultGemmConf::kAlignmentA, + ElementInputB, + cutlass::layout::ColumnMajor, + DefaultGemmConf::kAlignmentB, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + NumStages, + true, + typename DefaultGemmConf::Operator>::GemmKernel; using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, - GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>, + GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, + cutlass::sizeof_bits::value>, ElementCompute>; using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol< - ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator, - typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>; + ThreadblockShape, + GemmKernel_::kThreadCount, + AlphaColTileIterator, + typename GemmKernel_::Epilogue::OutputTileIterator, + ElementAccumulator, + ElementCompute, + EpilogueOutputOp>; - using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< - EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpilogueWithVisitorFromExistingEpilogue::Epilogue; using GemmKernel = cutlass::gemm::kernel::GemmWithEpilogueVisitor; @@ -104,98 +134,164 @@ void 
cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons typename EpilogueOutputOp::Params linearScalingParams; typename EpilogueVisitor::Arguments visitor_args{linearScalingParams}; - typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, - {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; + typename Gemm::Arguments args{ + {m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; - auto workspace = torch::empty(gemm_op.get_workspace_size(args), - torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + auto workspace = torch::empty( + gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess, - "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + TORCH_CHECK( + can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", + cutlassGetStatusString(can_implement)); auto status = gemm_op(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); } template -void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm75_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); if (m <= 32) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else if (m <= 64) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else if (m <= 256) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 2>(out, mat_a, mat_b, scales_a, scales_b, bias); } } template -void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm80_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& 
scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); int n = mat_b.size(1); if (m <= 16) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, + 6>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<16, 64, 128>, + cutlass::gemm::GemmShape<16, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 32) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 6>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<32, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 64) { if (n <= 4096) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 128 && n < 8192) { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - cutlass_int8_scaled_mm, - cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, - scales_b, bias); + cutlass_int8_scaled_mm< + ElementOutput, + ArchTag, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + InstructionShape, + 5>(out, mat_a, mat_b, scales_a, scales_b, bias); } } -template -void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +template < + typename ElementOutput, + typename TileShape, + typename ClusterShape, + typename MainloopScheduleType, + bool WithBias> +void cutlass_int8_scaled_mm_sm90( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const 
torch::Tensor& scales_b, + const c10::optional& bias) { using ArchTag = cutlass::arch::Sm90; using ElementAccumulator = int32_t; @@ -213,50 +309,75 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; using TileSchedulerType = cutlass::gemm::PersistentScheduler; - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, - Stride, Int<0>, Int<0>>>; + using XScale = cutlass::epilogue::fusion:: + Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride, Int<0>, Int<0>>>; - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, - Stride, Int<1>, Int<0>>>; + using WScale = cutlass::epilogue::fusion:: + Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, Stride, Int<1>, Int<0>>>; - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, - Stride, Int<1>, Int<0>>>; + using Bias = cutlass::epilogue::fusion:: + Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, Stride, Int<1>, Int<0>>>; using Accum = cutlass::epilogue::fusion::Sm90AccFetch; // Scale - using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + using Compute0 = cutlass::epilogue::fusion:: + Sm90Compute; using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; - using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + using Compute1 = cutlass::epilogue::fusion:: + Sm90Compute; using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; // With bias - using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using ComputeWithBias = cutlass::epilogue::fusion:: + Sm90Compute; using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; using EpilogueEVT = typename cutlass::platform::conditional::type; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput, - cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementOutput, + cutlass::layout::RowMajor, + AlignmentC, + ElementOutput, + cutlass::layout::RowMajor, + AlignmentOutput, + EpilogueScheduleType, + EpilogueEVT>::CollectiveOp; using Stages = cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB, - cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages, + ArchTag, + OperatorClass, + ElementInputA, + cutlass::layout::RowMajor, + AlignmentA, + ElementInputB, + cutlass::layout::ColumnMajor, + AlignmentB, + ElementAccumulator, + TileShape, + ClusterShape, + Stages, MainloopScheduleType>::CollectiveOp; - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerType>; using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter; @@ -283,14 +404,15 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); - typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, - {m, n, k, 1}, - {a_ptr, stride_a, b_ptr, stride_b}, - {{}, // epilogue.thread - nullptr, - stride_c, - o_ptr, - stride_d}}; + typename Gemm::Arguments args = { + cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {a_ptr, stride_a, b_ptr, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + o_ptr, + stride_d}}; if constexpr (WithBias) { ElementOutput* bias_ptr = static_cast(bias->data_ptr()); @@ -308,23 +430,29 @@ void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, }; } - auto workspace = torch::empty(gemm_op.get_workspace_size(args), - torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + auto workspace = torch::empty( + gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); auto can_implement = gemm_op.can_implement(args); - TORCH_CHECK(can_implement == cutlass::Status::kSuccess, - "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + TORCH_CHECK( + can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", + cutlassGetStatusString(can_implement)); auto status = gemm_op(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); } template -void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_dispatch_bias( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { if (bias) { cutlass_int8_scaled_mm_sm90( out, mat_a, mat_b, scales_a, scales_b, bias); @@ -335,45 +463,73 @@ void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const to } template -void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const c10::optional& bias) { +void sm90_dispatch_shape( + torch::Tensor& out, + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { int m = mat_a.size(0); int n = mat_b.size(1); if (m <= 32) { if (n < 8192) { - return sm90_dispatch_bias, Shape<_1, _8, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_1, _8, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _128, _128>, + Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 64) { if (n < 8192) { - return sm90_dispatch_bias, Shape<_1, _4, _1>, - 
cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_1, _4, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_1, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _256>, + Shape<_1, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else if (m <= 128) { if (n <= 4096) { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _64, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } else { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_64, _128, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); } } else { - return sm90_dispatch_bias, Shape<_2, _1, _1>, - cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, - bias); + return sm90_dispatch_bias< + ElementOutput, + Shape<_128, _128, _128>, + Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, bias); } } -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias) { +torch::Tensor int8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias) { TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index d9290fe01..ea222c001 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -8,8 +8,8 @@ #include "utils.h" template -__global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, - const int64_t num_elements) { +__global__ void +per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s, const int64_t num_elements) { float max_value = 0.0f; unsigned int tid = threadIdx.x; unsigned int gid = blockIdx.x * blockDim.x + threadIdx.x; @@ -56,8 +56,11 @@ __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __r } template -__global__ void per_tensor_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output, - const float* __restrict__ scale, const int64_t num_elements) { +__global__ void per_tensor_quant_fp8_kernel( + const T* __restrict__ input, + FP8_TYPE* __restrict__ output, + const float* __restrict__ scale, + const int64_t num_elements) { const int gid = blockIdx.x * blockDim.x + threadIdx.x; const int grid_size = blockDim.x * gridDim.x; const float 
scale_val = 1.0f / (*scale); @@ -124,8 +127,10 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch } per_tensor_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), num_elements); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + num_elements); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu index e5a14602a..bb3135dad 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu @@ -17,10 +17,15 @@ __device__ __forceinline__ float GroupReduce(float val, const int tid) { } template -__global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, void* __restrict__ output_q, - float* __restrict__ output_s, const int group_size, - const int num_groups, const float eps, const float fp8_min, - const float fp8_max) { +__global__ void per_token_group_quant_fp8_kernel( + const T* __restrict__ input, + void* __restrict__ output_q, + float* __restrict__ output_s, + const int group_size, + const int num_groups, + const float eps, + const float fp8_min, + const float fp8_max) { const int groups_per_block = 16; const int local_group_id = threadIdx.x / 16; const int lane_id = threadIdx.x % 16; @@ -80,8 +85,14 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo } } -void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, - int64_t group_size, double eps, double fp8_min, double fp8_max) { +void sgl_per_token_group_quant_fp8( + torch::Tensor input, + torch::Tensor output_q, + torch::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max) { CHECK_INPUT(input); CHECK_INPUT(output_q); CHECK_INPUT(output_s); @@ -97,8 +108,14 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q, DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { per_token_group_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), output_q.data_ptr(), static_cast(output_s.data_ptr()), - group_size, num_groups, (float)eps, (float)fp8_min, (float)fp8_max); + static_cast(input.data_ptr()), + output_q.data_ptr(), + static_cast(output_s.data_ptr()), + group_size, + num_groups, + (float)eps, + (float)fp8_min, + (float)fp8_max); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index 5528ad8c5..1491af126 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -7,9 +7,12 @@ #include "utils.h" template -__global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q, - float* __restrict__ output_s, const int64_t hidden_dim, - const int64_t num_tokens) { +__global__ void per_token_quant_fp8_kernel( + const T* __restrict__ input, + FP8_TYPE* __restrict__ output_q, + float* __restrict__ output_s, + const int64_t hidden_dim, + const int64_t num_tokens) { const int token_idx = blockIdx.x; if (token_idx >= num_tokens) return; @@ -110,8 +113,11 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, 
[&] { per_token_quant_fp8_kernel<<>>( - static_cast(input.data_ptr()), static_cast(output_q.data_ptr()), - static_cast(output_s.data_ptr()), hidden_dim, num_tokens); + static_cast(input.data_ptr()), + static_cast(output_q.data_ptr()), + static_cast(output_s.data_ptr()), + hidden_dim, + num_tokens); return true; }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu index 473aae6f5..c5f37e556 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -25,9 +25,11 @@ limitations under the License. #define WARP_SIZE 32 template -__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, - int32_t* __restrict__ cumsum_buffer, size_t numel) { +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; @@ -39,10 +41,15 @@ __global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ } template -__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids, - int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, - int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) { +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum) { __shared__ int32_t shared_counts[WARP_SIZE][8]; const int warp_id = threadIdx.x / WARP_SIZE; @@ -91,17 +98,29 @@ __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_id } } -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, + torch::Tensor cumsum_buffer) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now."); DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto align_kernel = moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), - num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + align_kernel<<<1, 1024, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr()); const int block_threads = 256; const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -109,8 +128,10 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b 
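// (Launch structure, per the kernels above: a single 1024-thread block first
// builds the per-expert counts and the block_size-aligned cumsum, then the
// grid below, capped at max_blocks, runs count_and_sort_expert_tokens_kernel
// to scatter each token index into sorted_token_ids using that cumsum buffer.)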
const int actual_blocks = std::min(num_blocks, max_blocks); auto sort_kernel = count_and_sort_expert_tokens_kernel; - sort_kernel<<>>(topk_ids.data_ptr(), - sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + sort_kernel<<>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel()); }); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu index af44261cc..1bfd6fd84 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -23,10 +23,18 @@ // tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = // [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, // draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token] -__global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, - bool* tree_mask, int64_t* positions, int64_t* retrive_index, - int64_t* retrive_next_token, int64_t* retrive_next_sibling, int topk, int depth, - int draft_token_num) { +__global__ void build_tree_efficient( + int64_t* parent_list, + int64_t* selected_index, + int32_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -99,10 +107,18 @@ __global__ void build_tree_efficient(int64_t* parent_list, int64_t* selected_ind } } -void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk, - int64_t depth, int64_t draft_token_num) { +void build_tree_kernel_efficient( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + int64_t topk, + int64_t depth, + int64_t draft_token_num) { // TODO (ying) check shape // TODO (ying) check type int bs = parent_list.size(0); @@ -111,11 +127,17 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); build_tree_efficient<<>>( - static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), - static_cast(retrive_next_token.data_ptr()), static_cast(retrive_next_sibling.data_ptr()), - int32_t(topk), int32_t(depth), int32_t(draft_token_num)); + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num)); } // parent_list [bs, topk * (depth - 1) + 1)] @@ -124,8 +146,16 @@ void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_ind 
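// (Worked size check for the flattened tree_mask layout described below: with
// bs = 2, draft_token = 4, and verified_seq_len = {5, 7}, the mask holds
// 4*(5+4) + 4*(7+4) = 80 entries, which matches
// sum(verified_seq_len)*draft_token + bs*draft_token*draft_token = 12*4 + 2*4*4 = 80.)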
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = // [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, // draft_token, depth + 2] -__global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_t* verified_seq_len, bool* tree_mask, - int64_t* positions, int64_t* retrive_index, int topk, int depth, int draft_token_num) { +__global__ void build_tree( + int64_t* parent_list, + int64_t* selected_index, + int32_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int topk, + int depth, + int draft_token_num) { int bid = blockIdx.x; int tid = threadIdx.x; @@ -191,9 +221,16 @@ __global__ void build_tree(int64_t* parent_list, int64_t* selected_index, int32_ } } -void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len, - at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk, - int64_t depth, int64_t draft_token_num) { +void build_tree_kernel( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + int64_t topk, + int64_t depth, + int64_t draft_token_num) { // TODO (ying) check shape // TODO (ying) check type int bs = parent_list.size(0); @@ -202,8 +239,13 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); build_tree<<>>( - static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), int32_t(topk), - int32_t(depth), int32_t(draft_token_num)); + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num)); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu index 379a2a22c..6eaafdb5b 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cu @@ -29,12 +29,19 @@ using namespace flashinfer; // retrive_next_sibling: [bs, num_draft_tokens] // uniform_samples: [bs, num_draft_tokens] // target_probs: [bs, num_draft_tokens, vocab_size] -void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index, - at::Tensor accept_token_num, // mutable - at::Tensor candidates, at::Tensor retrive_index, - at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, - at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, - bool deterministic, int64_t cuda_stream = 0) { +void tree_speculative_sampling_target_only( + at::Tensor predicts, + at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor uniform_samples, + at::Tensor target_probs, + at::Tensor draft_probs, + bool deterministic, + int64_t cuda_stream = 0) { CHECK_INPUT(candidates); CHECK_INPUT(retrive_index); CHECK_INPUT(retrive_next_token); @@ 
-108,13 +115,24 @@ void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accep cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( - static_cast(predicts.data_ptr()), static_cast(accept_index.data_ptr()), - static_cast(accept_token_num.data_ptr()), static_cast(candidates.data_ptr()), - static_cast(retrive_index.data_ptr()), static_cast(retrive_next_token.data_ptr()), - static_cast(retrive_next_sibling.data_ptr()), static_cast(uniform_samples.data_ptr()), - static_cast(target_probs.data_ptr()), static_cast(draft_probs.data_ptr()), batch_size, - num_spec_step, num_draft_tokens, vocab_size, deterministic, stream); + static_cast(predicts.data_ptr()), + static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + static_cast(uniform_samples.data_ptr()), + static_cast(target_probs.data_ptr()), + static_cast(draft_probs.data_ptr()), + batch_size, + num_spec_step, + num_draft_tokens, + vocab_size, + deterministic, + stream); - TORCH_CHECK(status == cudaSuccess, - "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status))); + TORCH_CHECK( + status == cudaSuccess, + "TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status))); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh index b9a32d2a9..bf7099231 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh +++ b/sgl-kernel/src/sgl-kernel/csrc/speculative/speculative_sampling.cuh @@ -27,15 +27,29 @@ namespace sampling { using namespace cub; -template -__global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* accept_index, - IdType* accept_token_num, // mutable - IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, - IdType* retrive_next_sibling, DType* uniform_samples, - DType* target_probs, DType* draft_probs, uint32_t batch_size, - uint32_t num_speculative_tokens, uint32_t num_draft_tokens, - uint32_t d) { +template < + uint32_t BLOCK_THREADS, + BlockScanAlgorithm SCAN_ALGORITHM, + BlockReduceAlgorithm REDUCE_ALGORITHM, + uint32_t VEC_SIZE, + bool DETERMINISTIC, + typename DType, + typename IdType> +__global__ void TreeSpeculativeSamplingTargetOnly( + IdType* predicts, + IdType* accept_index, + IdType* accept_token_num, // mutable + IdType* candidates, + IdType* retrive_index, + IdType* retrive_next_token, + IdType* retrive_next_sibling, + DType* uniform_samples, + DType* target_probs, + DType* draft_probs, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens, + uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; extern __shared__ __align__(alignof(SamplingTempStorage)) @@ -140,37 +154,54 @@ __global__ void TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* acce } template -cudaError_t TreeSpeculativeSamplingTargetOnly(IdType* predicts, IdType* output_token_ids, - IdType* output_accepted_token_num, // mutable - IdType* candidates, IdType* retrive_index, IdType* retrive_next_token, - IdType* retrive_next_sibling, DType* uniform_samples, DType* target_probs, - DType* draft_probs, uint32_t batch_size, uint32_t num_speculative_tokens, - uint32_t num_draft_tokens, uint32_t d, 
bool deterministic, - cudaStream_t stream = 0) { +cudaError_t TreeSpeculativeSamplingTargetOnly( + IdType* predicts, + IdType* output_token_ids, + IdType* output_accepted_token_num, // mutable + IdType* candidates, + IdType* retrive_index, + IdType* retrive_next_token, + IdType* retrive_next_sibling, + DType* uniform_samples, + DType* target_probs, + DType* draft_probs, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens, + uint32_t d, + bool deterministic, + cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&predicts, - &output_token_ids, - &output_accepted_token_num, - &candidates, - &retrive_index, - &retrive_next_token, - &retrive_next_sibling, - &uniform_samples, - &target_probs, - &draft_probs, - &batch_size, - &num_speculative_tokens, - &num_draft_tokens, - &d}; + void* args[] = { + &predicts, + &output_token_ids, + &output_accepted_token_num, + &candidates, + &retrive_index, + &retrive_next_token, + &retrive_next_sibling, + &uniform_samples, + &target_probs, + &draft_probs, + &batch_size, + &num_speculative_tokens, + &num_draft_tokens, + &d}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TreeSpeculativeSamplingTargetOnly; + auto kernel = TreeSpeculativeSamplingTargetOnly< + BLOCK_THREADS, + SCAN_ALGO, + REDUCE_ALGO, + VEC_SIZE, + DETERMINISTIC, + DType, + IdType>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index f5ebffb12..5bc5c7083 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -42,8 +42,8 @@ using fptr_t = int64_t; void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); -void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, - int64_t cuda_stream); +void gemma_fused_add_rmsnorm( + at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); @@ -53,113 +53,219 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); */ #ifdef USE_ROCM // ROCM custom allreduce -fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int64_t rank, bool full_nvlink); +fptr_t init_custom_ar( + torch::Tensor& meta, + torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, + int64_t rank, + bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& 
reg_buffer, torch::Tensor& out); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, - const std::vector& offsets); +void register_buffer( + fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); std::tuple> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector& handles, - const std::vector>& offsets); +void register_graph_buffers( + fptr_t _fa, const std::vector& handles, const std::vector>& offsets); torch::Tensor allocate_meta_buffer(int64_t size); torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); #else // TRTLLM custom allreduce -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out); +fptr_t init_custom_ar( + int64_t rank_id, + int64_t world_size, + torch::Tensor& rank_data, + const std::vector& buffers, + const std::vector& tmp_result_buffers, + const std::vector& barrier_in, + const std::vector& barrier_out); void dispose(fptr_t _fa); void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector>& handles, - const std::vector>& offsets); +void register_graph_buffers( + fptr_t _fa, const std::vector>& handles, const std::vector>& offsets); #endif /* * From csrc/gemm */ -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); -torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); -torch::Tensor fp8_blockwise_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, - const torch::Tensor& scales_a, const torch::Tensor& scales_b, - const torch::Dtype& out_dtype); -void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, - double eps, double fp8_min, double fp8_max); +torch::Tensor int8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, + const c10::optional& bias); +torch::Tensor fp8_blockwise_scaled_mm( + const torch::Tensor& mat_a, + const torch::Tensor& mat_b, + const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const torch::Dtype& out_dtype); +void sgl_per_token_group_quant_fp8( + at::Tensor input, + at::Tensor output_q, + at::Tensor output_s, + int64_t group_size, + double eps, + double fp8_min, + double fp8_max); void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static); -void cublas_grouped_gemm(const std::vector& inputs, const std::vector& weights, - const std::vector& outputs, const torch::Dtype& out_dtype, - int64_t cublas_handle, int64_t cuda_stream); +void cublas_grouped_gemm( + const std::vector& inputs, + const std::vector& weights, + const std::vector& outputs, + 

 /*
  * From csrc/moe
  */
-void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
-                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
-                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
+void moe_align_block_size(
+    torch::Tensor topk_ids,
+    int64_t num_experts,
+    int64_t block_size,
+    torch::Tensor sorted_token_ids,
+    torch::Tensor experts_ids,
+    torch::Tensor num_tokens_post_pad,
+    torch::Tensor token_cnts_buffer,
+    torch::Tensor cumsum_buffer);

 /*
  * From csrc/speculative
  */
-void tree_speculative_sampling_target_only(at::Tensor predicts, at::Tensor accept_index,
-                                           at::Tensor accept_token_num,  // mutable
-                                           at::Tensor candidates, at::Tensor retrive_index,
-                                           at::Tensor retrive_next_token, at::Tensor retrive_next_sibling,
-                                           at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs,
-                                           bool deterministic = true, int64_t cuda_stream = 0);
+void tree_speculative_sampling_target_only(
+    at::Tensor predicts,
+    at::Tensor accept_index,
+    at::Tensor accept_token_num,  // mutable
+    at::Tensor candidates,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    at::Tensor uniform_samples,
+    at::Tensor target_probs,
+    at::Tensor draft_probs,
+    bool deterministic = true,
+    int64_t cuda_stream = 0);

-void build_tree_kernel_efficient(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
-                                 at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index,
-                                 at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, int64_t topk,
-                                 int64_t depth, int64_t draft_token_num);
+void build_tree_kernel_efficient(
+    at::Tensor parent_list,
+    at::Tensor selected_index,
+    at::Tensor verified_seq_len,
+    at::Tensor tree_mask,
+    at::Tensor positions,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    int64_t topk,
+    int64_t depth,
+    int64_t draft_token_num);

-void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Tensor verified_seq_len,
-                       at::Tensor tree_mask, at::Tensor positions, at::Tensor retrive_index, int64_t topk,
-                       int64_t depth, int64_t draft_token_num);
+void build_tree_kernel(
+    at::Tensor parent_list,
+    at::Tensor selected_index,
+    at::Tensor verified_seq_len,
+    at::Tensor tree_mask,
+    at::Tensor positions,
+    at::Tensor retrive_index,
+    int64_t topk,
+    int64_t depth,
+    int64_t draft_token_num);

 /*
  * From FlashInfer
  */
-void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
-             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
-void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                               std::optional<at::Tensor> maybe_min_p_arr, double min_p_val, bool deterministic,
-                               int64_t cuda_stream);
+void bmm_fp8(
+    at::Tensor A,
+    at::Tensor B,
+    at::Tensor D,
+    at::Tensor A_scale,
+    at::Tensor B_scale,
+    at::Tensor workspace_buffer,
+    int64_t cublas_handle,
+    int64_t cuda_stream);
+void min_p_sampling_from_probs(
+    at::Tensor probs,
+    at::Tensor uniform_samples,
+    at::Tensor samples,
+    std::optional<at::Tensor> maybe_min_p_arr,
+    double min_p_val,
+    bool deterministic,
+    int64_t cuda_stream);
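Aside: the contract of moe_align_block_size above is easier to see in scalar form. A hedged CPU reference of just the padding arithmetic (expert histogram plus per-expert block rounding, ignoring the token reordering; the helper name is illustrative):

// CPU sketch of the block padding that moe_align_block_size performs on GPU.
#include <cstdint>
#include <vector>

int64_t padded_token_count(const std::vector<int64_t>& topk_ids, int64_t num_experts, int64_t block_size) {
  std::vector<int64_t> cnt(num_experts, 0);
  for (int64_t e : topk_ids) cnt[e]++;  // tokens routed to each expert
  int64_t total = 0;
  for (int64_t c : cnt) {
    total += (c + block_size - 1) / block_size * block_size;  // round each expert up to a full block
  }
  return total;  // length of sorted_token_ids after padding
}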

 // top k renorm probs
 // patched here because flashinfer uses unsigned int, but torch extensions must use int64_t
-void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr,
-                        unsigned int top_k_val, int64_t cuda_stream);
+void top_k_renorm_probs(
+    at::Tensor probs,
+    at::Tensor renorm_probs,
+    std::optional<at::Tensor> maybe_top_k_arr,
+    unsigned int top_k_val,
+    int64_t cuda_stream);

 // patched here because flashinfer uses unsigned int, but torch extensions must use int64_t
-inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs,
-                                       std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val,
-                                       int64_t cuda_stream) {
+inline void top_k_renorm_probs_wrapper(
+    at::Tensor probs,
+    at::Tensor renorm_probs,
+    std::optional<at::Tensor> maybe_top_k_arr,
+    int64_t top_k_val,
+    int64_t cuda_stream) {
   top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast<unsigned int>(top_k_val), cuda_stream);
 }

-void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr,
-                        double top_p_val, int64_t cuda_stream);
-void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples,
-                                     at::Tensor success, std::optional<at::Tensor> maybe_top_k_arr, double top_k_val,
-                                     std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
-                                     int64_t cuda_stream);
-void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success,
-                               std::optional<at::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic,
-                               int64_t cuda_stream);
-void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope,
-                                      at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave,
-                                      int64_t cuda_stream);
+void top_p_renorm_probs(
+    at::Tensor probs,
+    at::Tensor renorm_probs,
+    std::optional<at::Tensor> maybe_top_p_arr,
+    double top_p_val,
+    int64_t cuda_stream);
+void top_k_top_p_sampling_from_probs(
+    at::Tensor probs,
+    at::Tensor uniform_samples,
+    at::Tensor samples,
+    at::Tensor success,
+    std::optional<at::Tensor> maybe_top_k_arr,
+    double top_k_val,
+    std::optional<at::Tensor> maybe_top_p_arr,
+    double top_p_val,
+    bool deterministic,
+    int64_t cuda_stream);
+void top_p_sampling_from_probs(
+    at::Tensor probs,
+    at::Tensor uniform_samples,
+    at::Tensor samples,
+    at::Tensor success,
+    std::optional<at::Tensor> maybe_top_p_arr,
+    double top_p_val,
+    bool deterministic,
+    int64_t cuda_stream);
+void apply_rope_pos_ids_cos_sin_cache(
+    at::Tensor q,
+    at::Tensor k,
+    at::Tensor q_rope,
+    at::Tensor k_rope,
+    at::Tensor cos_sin_cache,
+    at::Tensor pos_ids,
+    bool interleave,
+    int64_t cuda_stream);

 /*
  * Other
  */
-void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
-                                const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
-                                torch::Tensor new_kv);
+void lightning_attention_decode(
+    const torch::Tensor& q,
+    const torch::Tensor& k,
+    const torch::Tensor& v,
+    const torch::Tensor& past_kv,
+    const torch::Tensor& slope,
+    torch::Tensor output,
+    torch::Tensor new_kv);

 // sgl_per_token_quant_fp8
 void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh
index f4b01230c..c670c994d 100644
--- a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh
@@ -103,7 +103,7 @@ inline AllReduceStrategyType SelectImplementation(size_t message_size, int world
   return AllReduceStrategyType::TWOSHOT;
 }

-void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
-                        cudaStream_t stream);
+void trtCustomAllReduce(
+    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);

 }  // namespace trt_llm
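One pattern worth noting from the sgl_kernels_ops.h hunk above: torch extension schemas only admit int64_t integers, so a third-party kernel taking unsigned int gets a thin narrowing wrapper at the binding boundary. A standalone sketch of the same idea (third_party_kernel is hypothetical, not an API in this repo):

// Torch extension schemas expose integers as int64_t; narrow once, at the boundary.
#include <cstdint>

void third_party_kernel(unsigned int top_k);  // hypothetical external signature

inline void third_party_kernel_wrapper(int64_t top_k) {
  third_party_kernel(static_cast<unsigned int>(top_k));
}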
diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h
index 94bcefa7f..b2960954b 100644
--- a/sgl-kernel/src/sgl-kernel/include/utils.h
+++ b/sgl-kernel/src/sgl-kernel/include/utils.h
@@ -95,7 +95,6 @@ inline int getSMVersion() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

 #define CEILDIV(x, y) (((x) + (y)-1) / (y))
-
 #define WARP_SIZE 32

 #ifndef USE_ROCM
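For context, CEILDIV from the utils.h hunk above is the usual grid-sizing helper. A hedged CUDA sketch of its typical use (the kernel and launcher names are illustrative, not code from this patch):

// Typical CEILDIV use when sizing a 1D launch grid: one thread per element.
#include <cuda_runtime.h>

#define CEILDIV(x, y) (((x) + (y)-1) / (y))

__global__ void scale_kernel(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void launch_scale(float* data, float factor, int n, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = CEILDIV(n, threads);  // e.g. n = 10000 -> 40 blocks
  scale_kernel<<<blocks, threads, 0, stream>>>(data, factor, n);
}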