[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)
Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
|
||||
)
|
||||
del gateup_input
|
||||
del gateup_input_fp8
|
||||
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
|
||||
)
|
||||
del down_input
|
||||
del down_input_fp8
|
||||
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8[0])
|
||||
|
||||
@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
|
||||
(
|
||||
down_input_scale
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
|
||||
down_input_scale
|
||||
)
|
||||
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
|
||||
),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
|
||||
)
|
||||
|
||||
return down_output
|
||||
|
||||
@@ -1,26 +1,22 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
DEEPGEMM_BLACKWELL,
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var, get_int_env_var
|
||||
from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
|
||||
from deep_gemm import get_num_sms
|
||||
from deep_gemm.jit import build
|
||||
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
|
||||
|
||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
|
||||
# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
|
||||
# NVRTC may have performance loss with some cases.
|
||||
# And NVCC JIT speed is also 9x faster in the ref commit
|
||||
_USE_NVRTC_DEFAULT = "0"
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
try:
|
||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||
|
||||
get_nvcc_compiler()
|
||||
except:
|
||||
logger.warning(
|
||||
"NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
|
||||
"and may have performance loss with some cases."
|
||||
)
|
||||
_USE_NVRTC_DEFAULT = "1"
|
||||
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
|
||||
os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
|
||||
|
||||
|
||||
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||
# Default each rank will try compile all Ms to
|
||||
# load all symbols at the launch stages.
|
||||
# Avoid loading symbols at the serving stages.
|
||||
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
|
||||
_DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE
|
||||
|
||||
|
||||
class DeepGemmKernelType(IntEnum):
|
||||
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
|
||||
GEMM_NT_F8F8BF16 = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepGemmKernelHelper:
|
||||
name: str
|
||||
compile_func: Callable[
|
||||
[
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||
],
|
||||
None,
|
||||
]
|
||||
configure_func: Callable[
|
||||
[int, int, int, int, int],
|
||||
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||
]
|
||||
|
||||
|
||||
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
|
||||
|
||||
|
||||
# TODO improve naming
|
||||
def _compile_warning_1():
|
||||
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
|
||||
logger.warning(
|
||||
"Entering DeepGEMM JIT Pre-Compile session. "
|
||||
"It may takes a long time (typically 10-20 mins) "
|
||||
"if you have not run `sglang.compile_deep_gemm`. "
|
||||
"It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
|
||||
" for pre-compilation to reduce the overhead if you have not run it before. "
|
||||
"For example: "
|
||||
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
|
||||
)
|
||||
|
||||
|
||||
# TODO improve naming
|
||||
def _compile_warning_2():
|
||||
logger.warning(
|
||||
"Entering DeepGEMM JIT Single Kernel Compile session. "
|
||||
"And it will makes inference throughput becomes flaky. "
|
||||
"Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
|
||||
" for pre-compilation to solve this issue. "
|
||||
"For example: "
|
||||
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
|
||||
)
|
||||
|
||||
|
||||
def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||
) -> None:
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
|
||||
kwargs = {
|
||||
"GEMM_TYPE": GemmType.GroupedMasked,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": num_groups,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||
) -> None:
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
kwargs = {
|
||||
"GEMM_TYPE": GemmType.GroupedContiguous,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": 1,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
def _compile_gemm_nt_f8f8bf16_one(
|
||||
n: int,
|
||||
k: int,
|
||||
_: int, # _ is a dummy parameter to align with other interfaces
|
||||
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||
) -> None:
|
||||
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||
block_k = 128
|
||||
num_tma_threads = 128
|
||||
num_math_threads_per_group = 128
|
||||
kwargs = {
|
||||
"GEMM_TYPE": GemmType.Normal,
|
||||
"NUM_TMA_THREADS": num_tma_threads,
|
||||
"NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
|
||||
"N": n,
|
||||
"K": k,
|
||||
"NUM_GROUPS": 1,
|
||||
"BLOCK_M": block_m,
|
||||
"BLOCK_N": block_n,
|
||||
"BLOCK_K": block_k,
|
||||
"SWIZZLE_D_MODE": smem_config[1],
|
||||
"BLOCK_N_PADDING": smem_config[2],
|
||||
"NUM_STAGES": num_stages,
|
||||
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||
"NUM_SMS": num_sms,
|
||||
"SMEM_SIZE": smem_config[0],
|
||||
}
|
||||
|
||||
code = FP8GemmRuntime.generate(kwargs)
|
||||
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
# TODO further refactor warmup-related
|
||||
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
||||
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||
compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
|
||||
configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
|
||||
m, n, k, num_groups, num_sms, is_grouped_masked=True
|
||||
),
|
||||
),
|
||||
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
|
||||
compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
|
||||
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
|
||||
m, n, k, 1, num_sms, is_grouped_contiguous=True
|
||||
),
|
||||
),
|
||||
DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
|
||||
name="gemm_fp8_fp8_bf16_nt",
|
||||
compile_func=_compile_gemm_nt_f8f8bf16_one,
|
||||
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
|
||||
m, n, k, 1, num_sms
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# TODO improve code
|
||||
def _maybe_compile_deep_gemm_one_type_all(
|
||||
kernel_type: DeepGemmKernelType,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
m_list: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
global _INITIALIZATION_DICT
|
||||
global _BUILTIN_M_LIST
|
||||
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
|
||||
):
|
||||
_INITIALIZATION_DICT[query_key] = True
|
||||
|
||||
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
|
||||
_compile_warning_1()
|
||||
# TODO maybe improve logs
|
||||
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
|
||||
logger.warning(
|
||||
"Entering DeepGEMM JIT Pre-Compile session. "
|
||||
"It may takes a long time (typically 10-20 mins) "
|
||||
"if you have not run `sglang.compile_deep_gemm`. "
|
||||
"It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
|
||||
" for pre-compilation to reduce the overhead if you have not run it before. "
|
||||
"For example: "
|
||||
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Try DeepGEMM JIT Compiling for "
|
||||
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
|
||||
f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
|
||||
f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
|
||||
)
|
||||
|
||||
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
|
||||
num_sms = get_num_sms()
|
||||
collected_configs = set()
|
||||
for m in m_list if m_list is not None else _BUILTIN_M_LIST:
|
||||
# Put config into set to get unique configs and reduce cases to be compiled
|
||||
collected_configs.add(
|
||||
kernel_helper.configure_func(m, n, k, num_groups, num_sms)
|
||||
)
|
||||
compile_func = lambda config: kernel_helper.compile_func(
|
||||
n, k, num_groups, config
|
||||
_compile_deep_gemm_one_type_all(
|
||||
kernel_type=kernel_type,
|
||||
n=n,
|
||||
k=k,
|
||||
num_groups=num_groups,
|
||||
m_list=_BUILTIN_M_LIST,
|
||||
)
|
||||
thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||
if _IN_PRECOMPILE_STAGE:
|
||||
yield
|
||||
return
|
||||
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
|
||||
def _compile_deep_gemm_one_type_all(
|
||||
kernel_type: DeepGemmKernelType,
|
||||
n: int,
|
||||
k: int,
|
||||
num_groups: int,
|
||||
m_list: List[int],
|
||||
) -> None:
|
||||
if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
|
||||
m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
|
||||
m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))
|
||||
|
||||
from deep_gemm.jit.runtime import RuntimeCache
|
||||
executor = _BaseWarmupExecutor.create(
|
||||
kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
|
||||
)
|
||||
|
||||
origin_func = RuntimeCache.get
|
||||
# TODO can use multi thread
|
||||
for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
|
||||
executor.execute(m=m)
|
||||
|
||||
def __patched_func(self, *args, **kwargs):
|
||||
ret = origin_func(self, *args, **kwargs)
|
||||
if ret is None:
|
||||
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
|
||||
if not DEEPGEMM_BLACKWELL:
|
||||
_compile_warning_2()
|
||||
logger.warning(
|
||||
f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
|
||||
)
|
||||
return ret
|
||||
|
||||
RuntimeCache.get = __patched_func
|
||||
yield
|
||||
RuntimeCache.get = origin_func
|
||||
class _BaseWarmupExecutor:
|
||||
@staticmethod
|
||||
def create(kernel_type: DeepGemmKernelType, **kwargs):
|
||||
return {
|
||||
DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
|
||||
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
|
||||
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
|
||||
}[kernel_type](**kwargs)
|
||||
|
||||
def execute(self, m):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _empty_token_fp8(size):
|
||||
*dims, k = size
|
||||
return (
|
||||
torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _empty_block_fp8(size):
|
||||
*dims, n, k = size
|
||||
return (
|
||||
torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
|
||||
torch.empty(
|
||||
(*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
_BLOCK_SIZE = 128
|
||||
|
||||
|
||||
class _NormalWarmupExecutor(_BaseWarmupExecutor):
|
||||
def __init__(self, max_m: int, n: int, k: int, num_groups: int):
|
||||
self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
|
||||
self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
|
||||
self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
def execute(self, m):
|
||||
deep_gemm.fp8_gemm_nt(
|
||||
(self.lhs_q[:m], self.lhs_s[:m]),
|
||||
(self.rhs_q, self.rhs_s),
|
||||
self.out[:m],
|
||||
)
|
||||
|
||||
|
||||
class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
|
||||
def __init__(self, max_m: int, n: int, k: int, num_groups: int):
|
||||
self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
|
||||
self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
|
||||
self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
|
||||
self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
def execute(self, m):
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
|
||||
(self.lhs_q[:m], self.lhs_s[:m]),
|
||||
(self.rhs_q, self.rhs_s),
|
||||
self.out[:m],
|
||||
m_indices=self.m_indices[:m],
|
||||
)
|
||||
|
||||
|
||||
class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
|
||||
def __init__(self, max_m: int, n: int, k: int, num_groups: int):
|
||||
self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
|
||||
self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
|
||||
self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
|
||||
self.out = torch.empty(
|
||||
(num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
def execute(self, m):
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(
|
||||
(self.lhs_q, self.lhs_s),
|
||||
(self.rhs_q, self.rhs_s),
|
||||
self.out,
|
||||
masked_m=self.masked_m,
|
||||
# DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
|
||||
expected_m=m,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def deep_gemm_execution_hook(
|
||||
m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
|
||||
):
|
||||
# not supported yet
|
||||
if not DEEPGEMM_BLACKWELL:
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||
|
||||
with _log_jit_build(m, n, k, kernel_type):
|
||||
yield
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||
yield
|
||||
|
||||
@@ -24,14 +24,12 @@ def _compute_enable_deep_gemm():
|
||||
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
||||
|
||||
|
||||
def _is_blackwell_arch() -> bool:
|
||||
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
|
||||
return major == 10
|
||||
|
||||
|
||||
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
||||
|
||||
try:
|
||||
from deep_gemm import fp8_gemm_nt
|
||||
|
||||
# They have not given a name to this breaking change
|
||||
DEEPGEMM_BLACKWELL = True
|
||||
except ImportError:
|
||||
DEEPGEMM_BLACKWELL = False
|
||||
|
||||
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
|
||||
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
|
||||
|
||||
@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
|
||||
if DEEPGEMM_BLACKWELL:
|
||||
from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import (
|
||||
fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
else:
|
||||
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
|
||||
|
||||
|
||||
# TODO maybe rename these functions
|
||||
def grouped_gemm_nt_f8f8bf16_masked(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
recipe=None,
|
||||
):
|
||||
num_groups, _, k = lhs[0].shape
|
||||
_, n, _ = rhs[0].shape
|
||||
@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
|
||||
with compile_utils.deep_gemm_execution_hook(
|
||||
expected_m, n, k, num_groups, kernel_type
|
||||
):
|
||||
_grouped_gemm_nt_f8f8bf16_masked_raw(
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
masked_m,
|
||||
expected_m,
|
||||
**({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
|
||||
)
|
||||
|
||||
|
||||
@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
_grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)
|
||||
|
||||
|
||||
def gemm_nt_f8f8bf16(
|
||||
@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
_gemm_nt_f8f8bf16_raw(
|
||||
deep_gemm.fp8_gemm_nt(
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
|
||||
@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
|
||||
)
|
||||
|
||||
if scale_ue8m0:
|
||||
from deep_gemm.utils.layout import transform_sf_into_required_layout
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
assert group_size == 128
|
||||
x_s = transform_sf_into_required_layout(
|
||||
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
|
||||
# scale_ue8m0=scale_ue8m0,
|
||||
# )
|
||||
|
||||
from deep_gemm.utils.layout import transform_sf_into_required_layout
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
|
||||
|
||||
|
||||
@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
|
||||
import deep_gemm.utils.layout
|
||||
|
||||
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
|
||||
sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
|
||||
sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
|
||||
return sf
|
||||
|
||||
out_s = _transform_scale(out_s, mn=out_w.shape[-2])
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
|
||||
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
|
||||
|
||||
@classmethod
|
||||
def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
|
||||
def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
|
||||
"""Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
|
||||
Args:
|
||||
input (torch.Tensor): The input tensor to be quantized.
|
||||
|
||||
Reference in New Issue
Block a user