[Feat] Enable PDL automatically on Hopper architecture (#5981)

This commit is contained in:
Huapeng Zhou
2025-06-01 12:30:17 -07:00
committed by GitHub
parent c6a0cacc35
commit 2f7420bc84
2 changed files with 28 additions and 9 deletions

View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
import torch import torch
from sgl_kernel.utils import get_cuda_stream from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -11,7 +11,7 @@ def rmsnorm(
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
enable_pdl: bool = False, enable_pdl: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
r"""Root mean square normalization. r"""Root mean square normalization.
@@ -27,9 +27,10 @@ def rmsnorm(
Epsilon for numerical stability. Epsilon for numerical stability.
out: Optional[torch.Tensor] out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace. The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_ <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns Returns
------- -------
@@ -38,6 +39,8 @@ def rmsnorm(
""" """
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl) torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out return out
@@ -47,7 +50,7 @@ def fused_add_rmsnorm(
residual: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
enable_pdl: bool = False, enable_pdl: Optional[bool] = None,
) -> None: ) -> None:
r"""Fused add root mean square normalization. r"""Fused add root mean square normalization.
@@ -67,10 +70,13 @@ def fused_add_rmsnorm(
Weight tensor, shape (hidden_size,). Weight tensor, shape (hidden_size,).
eps: float eps: float
Epsilon for numerical stability. Epsilon for numerical stability.
enable_pdl: bool enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_ <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
""" """
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.fused_add_rmsnorm.default( torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl input, residual, weight, eps, enable_pdl
) )
@@ -81,7 +87,7 @@ def gemma_rmsnorm(
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
enable_pdl: bool = False, enable_pdl: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
r"""Gemma-style root mean square normalization. r"""Gemma-style root mean square normalization.
@@ -97,9 +103,10 @@ def gemma_rmsnorm(
Epsilon for numerical stability. Epsilon for numerical stability.
out: Optional[torch.Tensor] out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace. The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_ <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
Returns Returns
------- -------
@@ -108,6 +115,8 @@ def gemma_rmsnorm(
""" """
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl) torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out return out
@@ -117,7 +126,7 @@ def gemma_fused_add_rmsnorm(
residual: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
enable_pdl: bool = False, enable_pdl: Optional[bool] = None,
) -> None: ) -> None:
r"""Gemma-style fused add root mean square normalization. r"""Gemma-style fused add root mean square normalization.
@@ -137,10 +146,13 @@ def gemma_fused_add_rmsnorm(
Weight tensor, shape (hidden_size,). Weight tensor, shape (hidden_size,).
eps: float eps: float
Epsilon for numerical stability. Epsilon for numerical stability.
enable_pdl: bool enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_ <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
""" """
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default( torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl input, residual, weight, eps, enable_pdl
) )

View File

@@ -39,3 +39,10 @@ def _to_tensor_scalar_tuple(x):
return (x, 0) return (x, 0)
else: else:
return (None, x) return (None, x)
def is_hopper_arch() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major == 9