From 6cdcbcc674542e58a441de4e40533bea522180c6 Mon Sep 17 00:00:00 2001
From: JieXin Liang
Date: Tue, 19 Aug 2025 01:16:08 +0800
Subject: [PATCH] [fix] fix enable_pdl for blackwell (#9011)

---
 sgl-kernel/python/sgl_kernel/elementwise.py | 14 +++++++-------
 sgl-kernel/python/sgl_kernel/utils.py       |  6 +++---
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index f25cc0431..559d6ef39 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
 import torch
-from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
+from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -41,7 +41,7 @@ def rmsnorm(
     if out is None:
         out = torch.empty_like(input)
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -77,7 +77,7 @@ def fused_add_rmsnorm(
-        If None, will be automatically enabled on Hopper architecture.
+        If None, will be automatically enabled on Hopper architecture or newer.
     """
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
     )
@@ -117,7 +117,7 @@ def gemma_rmsnorm(
     if out is None:
         out = torch.empty_like(input)
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -153,7 +153,7 @@ def gemma_fused_add_rmsnorm(
-        If None, will be automatically enabled on Hopper architecture.
+        If None, will be automatically enabled on Hopper architecture or newer.
     """
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
     )
diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py
index 2960d3419..f2fa0b617 100644
--- a/sgl-kernel/python/sgl_kernel/utils.py
+++ b/sgl-kernel/python/sgl_kernel/utils.py
@@ -43,8 +43,8 @@ def _to_tensor_scalar_tuple(x):
 
 
 @functools.lru_cache(maxsize=1)
-def is_hopper_arch() -> bool:
-    # Hopper arch's compute capability == 9.0
+def is_arch_support_pdl() -> bool:
+    # PDL needs compute capability >= 9.0, i.e. Hopper and newer (incl. Blackwell)
     device = torch.cuda.current_device()
     major, minor = torch.cuda.get_device_capability(device)
-    return major == 9
+    return major >= 9