Cache the result of is_blackwell platform check (#10498)
This commit is contained in:
@@ -2,7 +2,7 @@ import logging

 import torch

-from sglang.srt.utils import get_bool_env_var, get_device_sm
+from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell

 logger = logging.getLogger(__name__)

@@ -21,12 +21,7 @@ def _compute_enable_deep_gemm():
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


-def _is_blackwell_arch() -> bool:
-    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
-    return major == 10
-
-
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

-DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
@@ -167,6 +167,7 @@ is_ampere_with_cuda_12_3 = lambda: _check(8)
 is_hopper_with_cuda_12_3 = lambda: _check(9)


+@lru_cache(maxsize=1)
 def is_blackwell():
     if not is_cuda():
         return False
Reference in New Issue
Block a user