[Feat] Update sgl-kernel flashinfer to latest main version (#5500)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -21,7 +21,8 @@ limitations under the License.
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
|
||||
void sgl_fused_add_rmsnorm(
|
||||
torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
|
||||
static_cast<c_type*>(weight.data_ptr()),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
input.stride(0),
|
||||
residual.stride(0),
|
||||
eps,
|
||||
enable_pdl,
|
||||
torch_current_stream);
|
||||
TORCH_CHECK(
|
||||
status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
|
||||
|
||||
Reference in New Issue
Block a user