[fix] fix enable_pdl for blackwell (#9011)
@@ -2,7 +2,7 @@ from dataclasses import dataclass
 from typing import Optional
 
 import torch
-from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
+from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
 
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -41,7 +41,7 @@ def rmsnorm(
     if out is None:
         out = torch.empty_like(input)
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -77,7 +77,7 @@ def fused_add_rmsnorm(
         If None, will be automatically enabled on Hopper architecture.
     """
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
     )
@@ -117,7 +117,7 @@ def gemma_rmsnorm(
     if out is None:
         out = torch.empty_like(input)
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -153,7 +153,7 @@ def gemma_fused_add_rmsnorm(
         If None, will be automatically enabled on Hopper architecture.
     """
     if enable_pdl is None:
-        enable_pdl = is_hopper_arch()
+        enable_pdl = is_arch_support_pdl()
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
    )
@@ -43,8 +43,8 @@ def _to_tensor_scalar_tuple(x):
 
 
 @functools.lru_cache(maxsize=1)
-def is_hopper_arch() -> bool:
+def is_arch_support_pdl() -> bool:
     # Hopper arch's compute capability == 9.0
     device = torch.cuda.current_device()
     major, minor = torch.cuda.get_device_capability(device)
-    return major == 9
+    return major >= 9
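For reference, here is a minimal standalone sketch of what the change amounts to. PDL (Programmatic Dependent Launch) requires compute capability 9.0 or newer, so Hopper (sm_90, 9.x) and Blackwell (sm_100/sm_120, 10.x/12.x) both qualify; the old `major == 9` check silently disabled PDL on Blackwell. The usage lines assume the sgl-kernel package is installed and a CUDA device is present; they are an illustration, not part of the commit.

import functools

import torch


@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
    # PDL needs compute capability >= 9.0: Hopper is 9.x, Blackwell is
    # 10.x/12.x, so `>= 9` covers both where `== 9` only matched Hopper.
    device = torch.cuda.current_device()
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9


# Usage sketch (assumes sgl-kernel is installed):
from sgl_kernel import rmsnorm

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)
y = rmsnorm(x, w, eps=1e-6)                    # enable_pdl=None -> auto-detect
y = rmsnorm(x, w, eps=1e-6, enable_pdl=False)  # explicit argument still wins

Because the helper is wrapped in functools.lru_cache(maxsize=1), the device query runs once per process; the answer is pinned to whichever device is current on the first call.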