@@ -16,11 +16,7 @@ if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

     sm_version = get_device_sm()
@@ -45,10 +41,15 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may cause a performance loss in some cases.
+# NVCC JIT compilation is also 9x faster as of the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
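
The hunk above retargets the forced cache redirect at DG_JIT_CACHE_DIR and adds an opt-in NVRTC switch; SGL_DG_CACHE_DIR still overrides the default location. A minimal sketch of how the os.getenv fallback resolves (illustrative paths, not part of the patch):

    import os

    # No override set: the cache defaults to ~/.cache/deep_gemm.
    os.environ.pop("SGL_DG_CACHE_DIR", None)
    default_dir = os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
    assert os.getenv("SGL_DG_CACHE_DIR", default_dir) == default_dir

    # An explicit override wins over the default (hypothetical path).
    os.environ["SGL_DG_CACHE_DIR"] = "/tmp/deep_gemm_cache"
    assert os.getenv("SGL_DG_CACHE_DIR", default_dir) == "/tmp/deep_gemm_cache"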
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


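Taken together, the two hunks above drop the includes/template/arg_defs plumbing and hand jit_tuner.compile_and_tune a plain kwargs dict plus runtime_cls=FP8GemmRuntime. As a sketch, the post-patch body of _compile_grouped_gemm_nt_f8f8bf16_masked_one reads roughly as follows (assembled from the added lines and relying on this module's imports; the keys between the two hunks, such as "K" and the block sizes, are unchanged and elided):

    def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
        n: int,
        k: int,
        num_groups: int,
        config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
    ) -> None:
        num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
        block_k = 128
        num_tma_threads = 128
        num_math_threads_per_group = 128
        kwargs = {
            "NUM_TMA_THREADS": num_tma_threads,
            "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
            "BLOCK_K": block_k,
            "NUM_SMS": num_sms,
            "SMEM_SIZE": smem_config[0],
        }
        _, _ = jit_tuner.compile_and_tune(
            name="m_grouped_gemm_fp8_fp8_bf16_nt",
            keys={
                "N": n,
                # ... unchanged keys (K, block sizes, stages, TMA multicast) elided ...
                "GEMM_TYPE": GemmType.GroupedMasked,
            },
            space=(),
            kwargs=kwargs,
            runtime_cls=FP8GemmRuntime,
        )

The contiguous and plain GEMM helpers below follow the same pattern, differing mainly in the GemmType value, the kernel name, and, for the plain GEMM, the extra "GEMM_TYPE" and "NUM_GROUPS" entries in kwargs.
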
@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )


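All three compile helpers above unpack the same shape of config tuple, the kind produced by get_best_configs (imported above). An illustrative example of that tuple and how it maps onto the kwargs entries (values are made up, not taken from the patch):

    # (num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config)
    config = (
        132,               # num_sms -> "NUM_SMS"
        64,                # block_m
        128,               # block_n
        5,                 # num_stages
        (2, True),         # tma_multicast_config: (NUM_TMA_MULTICAST, IS_TMA_MULTICAST_ON_A)
        (199680, 128, 0),  # smem_config: smem_config[0] -> "SMEM_SIZE"; the rest is unused here
    )
    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
    smem_size = smem_config[0]
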
@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
             )
         return ret

-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
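
The last two hunks repoint the JIT-build logging shim from RuntimeCache.__getitem__ to RuntimeCache.get, apparently to match the updated DeepGEMM runtime-cache interface. A self-contained sketch of the same monkey-patch pattern, using a toy class rather than the real deep_gemm RuntimeCache:

    from contextlib import contextmanager

    class _ToyCache:  # stand-in for deep_gemm.jit.runtime.RuntimeCache
        def get(self, key):
            return None  # pretend every lookup misses, i.e. a JIT build is needed

    @contextmanager
    def _log_cache_miss():
        origin_get = _ToyCache.get

        def _patched_get(self, *args, **kwargs):
            ret = origin_get(self, *args, **kwargs)
            if ret is None:
                print("cache miss -> JIT build will run")
            return ret

        _ToyCache.get = _patched_get
        try:
            yield
        finally:
            _ToyCache.get = origin_get  # restore even if the build raises

    with _log_cache_miss():
        _ToyCache().get("some-kernel-key")  # prints the cache-miss message

The patch itself restores the original function right after yield; the try/finally here is just a defensive variant of the same idea.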