minor: cleanup sgl-kernel (#3143)
This commit is contained in:
@@ -40,6 +40,10 @@ Development build:
|
||||
make build
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
|
||||
|
||||
### Testing & Benchmarking
|
||||
|
||||
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
|
||||
|
||||
@@ -82,10 +82,8 @@ sources = [
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
|
||||
"src/sgl-kernel/csrc/rotary_embedding.cu",
|
||||
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
|
||||
"3rdparty/flashinfer/csrc/activation.cu",
|
||||
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
|
||||
"3rdparty/flashinfer/csrc/group_gemm.cu",
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
// Adapted from
|
||||
// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu
|
||||
|
||||
#include <turbomind/kernels/core/array_ops.h>
|
||||
#include <turbomind/kernels/core/common.h>
|
||||
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
|
||||
using namespace turbomind;
|
||||
|
||||
// Fused bias-add + residual-add + RMSNorm over one token row per thread block.
//
// Effect (per row of `dims` elements):
//   residual      <- residual + hidden_states (+ bias, when bias != nullptr)
//   hidden_states <- residual * rsqrt(mean(residual^2) + eps) * weights
//
// Launch contract: one block per token (blockIdx.x indexes tokens),
// blockDim.x == block_dim, each thread handles `vec_size` contiguous elements
// per stride step. NOTE(review): the vectorized Load/Store appear to assume
// dims % vec_size == 0 — confirm at the call sites.
template <class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states,
                                          const T* __restrict__ weights, const T* __restrict__ bias, int dims, int num,
                                          float eps, float inv_dims) {
  const int token = blockIdx.x;
  const int first_dim = threadIdx.x * vec_size;

  if (token >= num) {
    return;
  }

  // Advance both row pointers to this token's row.
  residual += dims * token;
  hidden_states += dims * token;

  // Per-thread accumulator for the sum of squares, kept in Tacc precision.
  Array<Tacc, vec_size> sq_accum{};

  Array<T, vec_size> res_vec;
  Array<T, vec_size> hid_vec;
  Array<T, vec_size> bias_vec;

  // Pass 1: residual += hidden_states (+ bias); accumulate squared values.
  for (int d = first_dim; d < dims; d += block_dim * vec_size) {
    Load(res_vec, &residual[d]);
    Load(hid_vec, &hidden_states[d]);

    using namespace ops;
    res_vec = res_vec + hid_vec;

    if (bias) {
      Ldg(bias_vec, &bias[d]);
      res_vec = res_vec + bias_vec;
    }

    Store(&residual[d], res_vec);

    Array<Tacc, vec_size> casted = cast<Tacc>(res_vec);
    sq_accum = sq_accum + casted * casted;
  }

  // Fold the per-lane vector accumulator into a scalar.
  float thread_sum{};
  PRAGMA_UNROLL
  for (int c = 0; c < vec_size; ++c) {
    thread_sum += sq_accum[c];
  }

  // Block-wide sum of squares; the result is only valid in thread 0.
  using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  thread_sum = BlockReduce{temp_storage}.Sum(thread_sum);

  // Thread 0 computes rsqrt(mean + eps); broadcast via shared memory.
  __shared__ float shared_inv_rms;
  if (threadIdx.x == 0) {
    shared_inv_rms = rsqrtf(thread_sum * inv_dims + eps);
  }
  __syncthreads();

  const float inv_rms = shared_inv_rms;

  // Pass 2: hidden_states = normalized residual scaled by the weights.
  Array<T, vec_size> w_vec;
  for (int d = first_dim; d < dims; d += block_dim * vec_size) {
    Load(res_vec, &residual[d]);
    Ldg(w_vec, &weights[d]);
    PRAGMA_UNROLL
    for (int c = 0; c < vec_size; ++c) {
      res_vec[c] = (T)((float)res_vec[c] * inv_rms) * w_vec[c];
    }
    Store(&hidden_states[d], res_vec);
  }
}
|
||||
|
||||
// Host-side launcher for BiasResidualRMSNormKernel: one CUDA block per token,
// 512 threads per block, 16-byte vector accesses (8 halves / 4 floats).
// NOTE(review): relies on the kernel's apparent dims % vec_size == 0
// assumption — confirm callers guarantee it.
template <class T>
void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num,
                               float eps, cudaStream_t st) {
  // 16 bytes per thread per step, expressed in elements of T.
  constexpr int kVecSize = 16 / sizeof(T);
  constexpr int kThreads = 512;

  BiasResidualRMSNormKernel<T, float, kThreads, kVecSize>
      <<<num, kThreads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
}
|
||||
@@ -3,8 +3,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "utils.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
|
||||
@@ -3,28 +3,14 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#else
|
||||
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
||||
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
|
||||
|
||||
@@ -39,7 +25,6 @@
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
|
||||
// Flatten a (row, col) coordinate into a row-major linear offset.
// 32-bit math is sufficient here: num_experts keeps the products small.
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
  const int32_t row_base = row * total_col;
  return row_base + col;
}
|
||||
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
// Element-wise repetition-penalty scaling:
//   output[i] = logits[i] > 0 ? logits[i] / penalty[i] : logits[i] * penalty[i]
//
// Grid-stride loop over 16-byte flashinfer vectors, with a scalar tail loop
// for the remainder, so `numel` needs no alignment.
//
// Fix: all indexing is 64-bit. The host passes `logits.numel()` (int64_t);
// the previous `const int32_t numel` parameter silently truncated for
// tensors with >= 2^31 elements. Widening the parameter is call-site
// compatible (the int64_t argument now converts without narrowing).
template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int64_t numel) {
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int64_t num_vec_elems = numel / vec_size;

  // Vectorized main loop (unroll 1: keep register pressure predictable).
#pragma unroll 1
  for (int64_t i = tid; i < num_vec_elems; i += stride) {
    vec_t logits_vec, penalties_vec, out_vec;
    logits_vec.cast_load(logits + i * vec_size);
    penalties_vec.cast_load(scaling_penalties + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
    }

    out_vec.cast_store(output + i * vec_size);
  }

  // Scalar tail: elements not covered by a full vector.
  const int64_t start_idx = num_vec_elems * vec_size;
  for (int64_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
  }
}
|
||||
|
||||
// Applies repetition-penalty scaling to `logits` and returns a new tensor;
// the inputs are not modified. See sampling_scaling_penalties_kernel for the
// per-element formula.
//
// Fixes over the original:
//   * validates that both tensors are CUDA, contiguous, and have identical
//     shape and dtype (the kernel indexes raw data_ptr() linearly, so a
//     non-contiguous or mismatched input would silently compute garbage),
//   * activates the logits' device before launching,
//   * surfaces kernel-launch failures instead of ignoring them.
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  TORCH_CHECK(logits.is_cuda(), "sampling_scaling_penalties: logits must be a CUDA tensor");
  TORCH_CHECK(scaling_penalties.is_cuda(), "sampling_scaling_penalties: scaling_penalties must be a CUDA tensor");
  TORCH_CHECK(logits.is_contiguous(), "sampling_scaling_penalties: logits must be contiguous");
  TORCH_CHECK(scaling_penalties.is_contiguous(), "sampling_scaling_penalties: scaling_penalties must be contiguous");
  TORCH_CHECK(logits.sizes() == scaling_penalties.sizes(),
              "sampling_scaling_penalties: logits and scaling_penalties must have the same shape");
  TORCH_CHECK(logits.scalar_type() == scaling_penalties.scalar_type(),
              "sampling_scaling_penalties: logits and scaling_penalties must have the same dtype");

  // Run on the device that owns the inputs (CUDAGuard.h is already included).
  const c10::cuda::CUDAGuard device_guard(logits.device());

  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
    uint32_t vec_size = 16 / sizeof(scalar_t);
    // One vector's worth of elements per thread per grid-stride step.
    const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
    sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        static_cast<scalar_t*>(logits.data_ptr()), static_cast<scalar_t*>(scaling_penalties.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()), numel);
    // Catch bad launch configurations immediately (async execution errors
    // still surface at the next synchronizing call).
    cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "sampling_scaling_penalties: kernel launch failed: ", cudaGetErrorString(launch_err));
    return true;
  });

  return output;
}
|
||||
@@ -26,6 +26,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "trt_reduce_internal.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
using namespace trt_llm;
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <Python.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
@@ -36,9 +35,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
||||
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
||||
|
||||
// sampling_scaling_penalties
|
||||
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);
|
||||
|
||||
// int8_scaled_mm
|
||||
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdint.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace trt_llm {
|
||||
constexpr size_t WARP_SIZE = 32;
|
||||
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <pytorch_extension_utils.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sgl_kernels_ops.h"
|
||||
|
||||
struct cuda_error : public std::runtime_error {
|
||||
/**
|
||||
* @brief Constructs a `cuda_error` object with the given `message`.
|
||||
|
||||
@@ -28,10 +28,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
|
||||
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
|
||||
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
|
||||
|
||||
// sampling_scaling_penalties
|
||||
m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
|
||||
m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);
|
||||
|
||||
// int8_scaled_mm
|
||||
m.def(
|
||||
"int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import sampling_scaling_penalties
|
||||
|
||||
# Sweep small-to-moderate batch sizes, including a non-power-of-two (65).
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
# Vocab sizes include 32767 so the kernel's scalar tail path is exercised.
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
dtypes = [torch.float32, torch.half, torch.bfloat16]


@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("vocab_size", vocab_sizes)
@pytest.mark.parametrize("dtype", dtypes)
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
    """Compare the CUDA sampling_scaling_penalties kernel to a torch reference.

    Positive logits are divided by the penalty; non-positive logits are
    multiplied by it.
    """
    device = torch.device("cuda")
    rtol = 1e-3
    atol = 1e-3

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    # Penalties drawn from [0.5, 1.5) so division stays well-conditioned.
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    # Pure-torch reference implementation of the penalty rule.
    positive = logits > 0
    expected = torch.where(
        positive, logits / scaling_penalties, logits * scaling_penalties
    )

    actual = sampling_scaling_penalties(logits, scaling_penalties)

    torch.testing.assert_close(
        actual,
        expected,
        rtol=rtol,
        atol=atol,
        msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}",
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user