[Feat] Update sgl-kernel flashinfer to latest main version (#5500)
Co-authored-by: zhyncs <me@zhyncs.com>
@@ -11,17 +11,69 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor; if specified, the kernel will update this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
+    r"""Fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
+        input, residual, weight, eps, enable_pdl
+    )
 
 
 def gemma_rmsnorm(
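For context, a minimal usage sketch of the updated signatures; the shapes, dtype, and top-level `sgl_kernel` imports are illustrative assumptions, not part of this diff:

    import torch
    from sgl_kernel import fused_add_rmsnorm, rmsnorm

    # Hypothetical inputs: batch of 4 tokens, hidden size 128.
    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    w = torch.ones(128, dtype=torch.float16, device="cuda")

    # Out-of-place RMSNorm; enable_pdl takes the slot previously used for the
    # stream argument and defaults to False, so existing callers keep working.
    y = rmsnorm(x, w, eps=1e-6)

    # Fused variant mutates both tensors in place:
    # residual += x, then x = RMSNorm(residual) * w.
    residual = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    fused_add_rmsnorm(x, residual, w, eps=1e-6, enable_pdl=False)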
@@ -29,20 +81,68 @@ def gemma_rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor; if specified, the kernel will update this tensor in place.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma-normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm.default(
-        out, input, weight, eps, get_cuda_stream()
-    )
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def gemma_fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
-        input, residual, weight, eps, get_cuda_stream()
+        input, residual, weight, eps, enable_pdl
     )
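The Gemma variants follow the same pattern, differing only in the `(weight[i] + 1)` scaling. Programmatic dependent launch requires a Hopper-class GPU (compute capability 9.0 or newer), so callers may want to gate the flag on the device; the guard below is a sketch of one way to do that, not something this diff adds:

    import torch
    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm

    # Assumed guard: only request PDL where the hardware supports it (sm90+).
    major, _ = torch.cuda.get_device_capability()
    use_pdl = major >= 9

    x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    w = torch.zeros(128, dtype=torch.float16, device="cuda")  # weight + 1 == 1

    y = gemma_rmsnorm(x, w, enable_pdl=use_pdl)

    # In-place fused add + Gemma RMSNorm over x and residual.
    residual = torch.randn(4, 128, dtype=torch.float16, device="cuda")
    gemma_fused_add_rmsnorm(x, residual, w, enable_pdl=use_pdl)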