Enable DeepGEMM by default on the Hopper architecture. (#4613)
This commit is contained in:
@@ -26,11 +26,14 @@ from sglang.srt.utils import (
|
|||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
|
get_device_sm,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_enable_jit_deepgemm = False
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
|
||||||
|
|
||||||
@@ -39,9 +42,12 @@ if _is_cuda:
|
|||||||
import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
|
import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
|
||||||
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
sm_version = get_device_sm()
|
||||||
|
if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
|
||||||
|
_enable_jit_deepgemm = True
|
||||||
|
|
||||||
_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
|
|
||||||
@@ -771,7 +777,7 @@ def w8a8_block_fp8_matmul(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# deepgemm only supports bf16
|
# deepgemm only supports bf16
|
||||||
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1006,6 +1006,13 @@ def get_amdgpu_memory_capacity():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_sm():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
major, minor = torch.cuda.get_device_capability()
|
||||||
|
return major * 10 + minor
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def get_nvgpu_memory_capacity():
|
def get_nvgpu_memory_capacity():
|
||||||
try:
|
try:
|
||||||
# Run nvidia-smi and capture the output
|
# Run nvidia-smi and capture the output
|
||||||
|
|||||||
Reference in New Issue
Block a user