diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index e25c7c333..2043805e7 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -26,11 +26,14 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
+    get_device_sm,
     is_cuda,
     is_hip,
     supports_custom_op,
 )

+_enable_jit_deepgemm = False
+
 _is_hip = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

@@ -39,9 +42,12 @@ if _is_cuda:
     import deep_gemm  # `pip install "sgl-kernel>=0.0.4.post3"`
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

-logger = logging.getLogger(__name__)
+    sm_version = get_device_sm()
+    if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+        _enable_jit_deepgemm = True

-_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
+logger = logging.getLogger(__name__)

 if supports_custom_op():

@@ -771,7 +777,7 @@ def w8a8_block_fp8_matmul(
     )

     # deepgemm only support bf16
-    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+    if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
         if supports_custom_op():
             torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
         else:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ddfba13f5..af2907f41 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1006,6 +1006,13 @@ def get_amdgpu_memory_capacity():
     )


+def get_device_sm():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    return 0
+
+
 def get_nvgpu_memory_capacity():
     try:
         # Run nvidia-smi and capture the output
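
With this patch, JIT DeepGEMM switches from opt-in (`SGL_ENABLE_JIT_DEEPGEMM=1` with a default of `"0"`) to on-by-default on SM90+ (Hopper-class) CUDA devices, with `SGL_ENABLE_JIT_DEEPGEMM=0` as the opt-out. A minimal standalone sketch of the new gate, assuming only `torch` is installed; `get_device_sm` mirrors the helper added to `utils.py`:

```python
import os
import torch

def get_device_sm() -> int:
    # Encode the CUDA compute capability as a two-digit SM number,
    # e.g. capability (9, 0) on an H100 -> 90. Returns 0 without CUDA.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0

# Default "1" means enabled on SM90+ unless the user exports
# SGL_ENABLE_JIT_DEEPGEMM=0 before this module is imported.
_enable_jit_deepgemm = (
    get_device_sm() >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")) != 0
)
```

Dropping the `_is_cuda` check from the dispatch in `w8a8_block_fp8_matmul` is safe because `_enable_jit_deepgemm` now starts as `False` and can only be set to `True` inside the `if _is_cuda:` block.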