[Feat] Update sgl-kernel flashinfer to latest main version (#5500)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
|
||||
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
|
||||
m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
|
||||
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
|
||||
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
|
||||
m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
|
||||
m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
|
||||
|
||||
m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
|
||||
@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
|
||||
|
||||
m.def(
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
"min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
|
||||
"min_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
|
||||
"cuda_stream) -> ()");
|
||||
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
|
||||
"cuda_stream) -> ()");
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
|
||||
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
/*
|
||||
|
||||
@@ -21,7 +21,8 @@ limitations under the License.
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
||||
void sgl_fused_add_rmsnorm(
|
||||
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
|
||||
static_cast<c_type*>(weight.data_ptr()),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
input.stride(0),
|
||||
residual.stride(0),
|
||||
eps,
|
||||
enable_pdl,
|
||||
torch_current_stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
|
||||
@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
DType threshold_acc) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
|
||||
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
|
||||
|
||||
DType prob_acc = 0.0;
|
||||
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
||||
@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
|
||||
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
|
||||
if (aggregate_relu_q_minus_p > u) {
|
||||
break;
|
||||
@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
|
||||
|
||||
Reference in New Issue
Block a user