[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
PGFLMG
2025-04-18 03:43:23 +08:00
committed by GitHub
parent f13d65a7ea
commit c08a717c77
8 changed files with 393 additions and 133 deletions

View File

@@ -21,7 +21,8 @@ limitations under the License.
using namespace flashinfer;
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
void sgl_fused_add_rmsnorm(
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
static_cast<c_type*>(weight.data_ptr()),
batch_size,
hidden_size,
input.stride(0),
residual.stride(0),
eps,
enable_pdl,
torch_current_stream);
TORCH_CHECK(
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));