[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
PGFLMG
2025-04-18 03:43:23 +08:00
committed by GitHub
parent f13d65a7ea
commit c08a717c77
8 changed files with 393 additions and 133 deletions

View File

@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
/*
 * From csrc/elementwise
 */
// NOTE(review): this span is a captured diff WITHOUT +/- markers. The first
// group of declarations appears to be the PRE-update API (explicit
// `int64_t cuda_stream` handle passed by the caller) and the second group the
// POST-update API from the flashinfer bump (`bool enable_pdl` — presumably
// CUDA programmatic dependent launch — replaces the stream parameter).
// Confirm against the upstream commit before treating either set as current.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
void gemma_fused_add_rmsnorm(
at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
// Post-update signatures: `enable_pdl` flag added; explicit stream argument dropped.
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
// Activation kernels — untouched by this diff; still take an explicit stream handle.
void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
@@ -254,48 +254,38 @@ void segment_packbits(
*/
// NOTE(review): captured diff without +/- markers — old and new parameter
// lines are interleaved INSIDE single declarations below (note the duplicate
// closing `...);` lines), so this span is NOT valid C++ as written. Per the
// visible changes, the sampling API appears to drop
// `uniform_samples`/`success`/`deterministic`/`cuda_stream` in favour of
// `output`, optional `maybe_indices`, and a `std::optional<at::Generator>`
// RNG argument — confirm against the upstream flashinfer commit. Trailing
// annotations mark which version each line seems to belong to.
void min_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples, // old: removed parameter
at::Tensor samples, // old: removed parameter
at::Tensor output, // new: added parameter
std::optional<at::Tensor> maybe_indices, // new: added parameter
std::optional<at::Tensor> maybe_min_p_arr,
double min_p_val,
bool deterministic, // old: removed parameter
int64_t cuda_stream); // old: closes the pre-change declaration
std::optional<at::Generator> gen); // new: closes the post-change declaration
void top_k_renorm_probs(
at::Tensor probs, // old: start of the multi-line pre-change parameter list
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_k_arr,
int64_t top_k_val,
int64_t cuda_stream); // old: pre-change declaration kept a stream handle
// new: single-line post-change parameter list (stream parameter dropped)
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
void top_p_renorm_probs(
at::Tensor probs, // old: start of the multi-line pre-change parameter list
at::Tensor renorm_probs,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
int64_t cuda_stream); // old: pre-change declaration kept a stream handle
// new: single-line post-change parameter list (stream parameter dropped)
at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
void top_k_top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples, // old: removed parameter
at::Tensor samples, // old: removed parameter
at::Tensor success, // old: removed out-parameter
at::Tensor output, // new: added parameter
std::optional<at::Tensor> maybe_indices, // new: added parameter
std::optional<at::Tensor> maybe_top_k_arr,
double top_k_val,
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic, // old: removed parameter
int64_t cuda_stream); // old: closes the pre-change declaration
std::optional<at::Generator> gen); // new: closes the post-change declaration
void top_p_sampling_from_probs(
at::Tensor probs,
at::Tensor uniform_samples, // old: removed parameter
at::Tensor samples, // old: removed parameter
at::Tensor success, // old: removed out-parameter
at::Tensor output, // new: added parameter
std::optional<at::Tensor> maybe_indices, // new: added parameter
std::optional<at::Tensor> maybe_top_p_arr,
double top_p_val,
bool deterministic, // old: removed parameter
int64_t cuda_stream); // old: closes the pre-change declaration
std::optional<at::Generator> gen); // new: closes the post-change declaration
namespace flash {
/*