[sgl-kernel] misc: update deepgemm version for sgl-kernel (#9340)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
.github/workflows/pr-test-sgl-kernel.yml (vendored)
@@ -38,6 +38,8 @@ jobs:
         include:
           - python-version: "3.10"
             cuda-version: "12.4"
           - python-version: "3.10"
            cuda-version: "12.8"
+          - python-version: "3.10"
+            cuda-version: "12.9"
     name: Build Wheel (CUDA ${{ matrix.cuda-version }})
@@ -248,7 +248,6 @@ class EPMoE(FusedMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -304,7 +303,6 @@ class EPMoE(FusedMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -667,7 +665,6 @@ class DeepEPMoE(EPMoE):
             gateup_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         dispose_tensor(hidden_states_fp8[0])

@@ -708,9 +705,7 @@ class DeepEPMoE(EPMoE):
                 (
                     down_input_scale
                     if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                    else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
-                        down_input_scale
-                    )
+                    else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
                 ),
             )
         down_output = torch.empty(
@@ -722,7 +717,6 @@ class DeepEPMoE(EPMoE):
             down_output,
             masked_m,
             expected_m,
-            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )

         return down_output
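The four hunks above all drop the per-call-site `recipe` argument. A minimal sketch of the resulting pattern (mirroring the wrapper change later in this diff; `deep_gemm` and `DEEPGEMM_BLACKWELL` are assumed to be in scope, as in the wrapper module):

```python
# Sketch only, not the verbatim source: the wrapper keeps `recipe=None` and
# forwards it to DeepGEMM only on Blackwell builds, so call sites stay
# identical across architectures.
def grouped_gemm_nt_f8f8bf16_masked(lhs, rhs, out, masked_m, expected_m, recipe=None):
    extra = {"recipe": recipe} if DEEPGEMM_BLACKWELL else {}
    deep_gemm.fp8_m_grouped_gemm_nt_masked(lhs, rhs, out, masked_m, expected_m, **extra)
```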
@@ -1,26 +1,22 @@
 import logging
 import os
 from contextlib import contextmanager
-from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple

-from tqdm.contrib.concurrent import thread_map
+import torch
+from tqdm import tqdm

 from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
     DEEPGEMM_BLACKWELL,
     ENABLE_JIT_DEEPGEMM,
 )
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_bool_env_var, get_int_env_var
+from sglang.srt.utils import ceil_div, get_bool_env_var, get_int_env_var

 logger = logging.getLogger(__name__)

-if ENABLE_JIT_DEEPGEMM and not DEEPGEMM_BLACKWELL:
-    from deep_gemm import get_num_sms
-    from deep_gemm.jit import build
-    from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm


 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
@@ -40,19 +36,7 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
 # Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
 # NVRTC may have performance loss with some cases.
 # And NVCC JIT speed is also 9x faster in the ref commit
-_USE_NVRTC_DEFAULT = "0"
-if ENABLE_JIT_DEEPGEMM:
-    try:
-        from deep_gemm.jit.compiler import get_nvcc_compiler
-
-        get_nvcc_compiler()
-    except:
-        logger.warning(
-            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
-            "and may have performance loss with some cases."
-        )
-        _USE_NVRTC_DEFAULT = "1"
-os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")


 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
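With the NVCC-probe fallback removed, the DeepGEMM JIT now defaults to NVCC unconditionally and NVRTC becomes purely opt-in. A hedged usage sketch (only the `SGL_DG_USE_NVRTC` variable name comes from the diff):

```python
import os

# Opt back into NVRTC for DeepGEMM JIT; must run before sglang imports the
# wrapper, which reads os.getenv("SGL_DG_USE_NVRTC", "0") at import time.
os.environ["SGL_DG_USE_NVRTC"] = "1"
```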
@@ -75,7 +59,7 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     # Default each rank will try compile all Ms to
     # load all symbols at the launch stages.
     # Avoid loading symbols at the serving stages.
-    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
+    _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE


 class DeepGemmKernelType(IntEnum):
@@ -84,185 +68,15 @@ class DeepGemmKernelType(IntEnum):
     GEMM_NT_F8F8BF16 = auto()


-@dataclass
-class DeepGemmKernelHelper:
-    name: str
-    compile_func: Callable[
-        [
-            int,
-            int,
-            int,
-            Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-        ],
-        None,
-    ]
-    configure_func: Callable[
-        [int, int, int, int, int],
-        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-    ]
-
-
 _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()


-# TODO improve naming
-def _compile_warning_1():
-    if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
-        logger.warning(
-            "Entering DeepGEMM JIT Pre-Compile session. "
-            "It may takes a long time (typically 10-20 mins) "
-            "if you have not run `sglang.compile_deep_gemm`. "
-            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-            " for pre-compilation to reduce the overhead if you have not run it before. "
-            "For example: "
-            "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-        )
-
-
-# TODO improve naming
-def _compile_warning_2():
-    logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Compile session. "
-        "And it will makes inference throughput becomes flaky. "
-        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
-        " for pre-compilation to solve this issue. "
-        "For example: "
-        "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
-    )
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedMasked,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": num_groups,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
-    n: int,
-    k: int,
-    num_groups: int,
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.GroupedContiguous,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("m_grouped_gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-def _compile_gemm_nt_f8f8bf16_one(
-    n: int,
-    k: int,
-    _: int,  # _ is a dummy parameter to align with other interfaces
-    config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
-) -> None:
-    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    block_k = 128
-    num_tma_threads = 128
-    num_math_threads_per_group = 128
-    kwargs = {
-        "GEMM_TYPE": GemmType.Normal,
-        "NUM_TMA_THREADS": num_tma_threads,
-        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
-        "N": n,
-        "K": k,
-        "NUM_GROUPS": 1,
-        "BLOCK_M": block_m,
-        "BLOCK_N": block_n,
-        "BLOCK_K": block_k,
-        "SWIZZLE_D_MODE": smem_config[1],
-        "BLOCK_N_PADDING": smem_config[2],
-        "NUM_STAGES": num_stages,
-        "NUM_TMA_MULTICAST": tma_multicast_config[0],
-        "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-        "NUM_SMS": num_sms,
-        "SMEM_SIZE": smem_config[0],
-    }
-
-    code = FP8GemmRuntime.generate(kwargs)
-    _ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
-
-
-# TODO further refactor warmup-related
-_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
-        configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
-            m, n, k, num_groups, num_sms, is_grouped_masked=True
-        ),
-    ),
-    DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
-        name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
-        compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms, is_grouped_contiguous=True
-        ),
-    ),
-    DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
-        name="gemm_fp8_fp8_bf16_nt",
-        compile_func=_compile_gemm_nt_f8f8bf16_one,
-        configure_func=lambda m, n, k, _, num_sms: get_best_configs(
-            m, n, k, 1, num_sms
-        ),
-    ),
-}
-
-
 # TODO improve code
 def _maybe_compile_deep_gemm_one_type_all(
     kernel_type: DeepGemmKernelType,
     n: int,
     k: int,
     num_groups: int,
-    m_list: Optional[List[int]] = None,
 ) -> None:
     global _INITIALIZATION_DICT
     global _BUILTIN_M_LIST
@@ -275,61 +89,145 @@ def _maybe_compile_deep_gemm_one_type_all(
     ):
         _INITIALIZATION_DICT[query_key] = True

-        kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-        _compile_warning_1()
+        # TODO maybe improve logs
+        if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
+            logger.warning(
+                "Entering DeepGEMM JIT Pre-Compile session. "
+                "It may takes a long time (typically 10-20 mins) "
+                "if you have not run `sglang.compile_deep_gemm`. "
+                "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+                " for pre-compilation to reduce the overhead if you have not run it before. "
+                "For example: "
+                "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
+            )
+
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
-            f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
+            f"<{kernel_type.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
             f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )

-        # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
-        num_sms = get_num_sms()
-        collected_configs = set()
-        for m in m_list if m_list is not None else _BUILTIN_M_LIST:
-            # Put config into set to get unique configs and reduce cases to be compiled
-            collected_configs.add(
-                kernel_helper.configure_func(m, n, k, num_groups, num_sms)
-            )
-        compile_func = lambda config: kernel_helper.compile_func(
-            n, k, num_groups, config
+        _compile_deep_gemm_one_type_all(
+            kernel_type=kernel_type,
+            n=n,
+            k=k,
+            num_groups=num_groups,
+            m_list=_BUILTIN_M_LIST,
         )
-        thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)


-@contextmanager
-def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
-    if _IN_PRECOMPILE_STAGE:
-        yield
-        return
+# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
+def _compile_deep_gemm_one_type_all(
+    kernel_type: DeepGemmKernelType,
+    n: int,
+    k: int,
+    num_groups: int,
+    m_list: List[int],
+) -> None:
+    if kernel_type == DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG:
+        m_alignment = deep_gemm.get_mk_alignment_for_contiguous_layout()
+        m_list = sorted(list(set(m for m in m_list if m % m_alignment == 0)))

-    from deep_gemm.jit.runtime import RuntimeCache
+    executor = _BaseWarmupExecutor.create(
+        kernel_type, max_m=max(m_list), n=n, k=k, num_groups=num_groups
+    )

-    origin_func = RuntimeCache.get
+    # TODO can use multi thread
+    for m in tqdm(m_list, desc=f"DeepGEMM warmup"):
+        executor.execute(m=m)

-    def __patched_func(self, *args, **kwargs):
-        ret = origin_func(self, *args, **kwargs)
-        if ret is None:
-            kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
-            if not DEEPGEMM_BLACKWELL:
-                _compile_warning_2()
-            logger.warning(
-                f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
-            )
-        return ret

-    RuntimeCache.get = __patched_func
-    yield
-    RuntimeCache.get = origin_func
+class _BaseWarmupExecutor:
+    @staticmethod
+    def create(kernel_type: DeepGemmKernelType, **kwargs):
+        return {
+            DeepGemmKernelType.GEMM_NT_F8F8BF16: _NormalWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: _GroupedContWarmupExecutor,
+            DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: _GroupedMaskedWarmupExecutor,
+        }[kernel_type](**kwargs)
+
+    def execute(self, m):
+        raise NotImplementedError
+
+
+def _empty_token_fp8(size):
+    *dims, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(k, _BLOCK_SIZE)), device="cuda", dtype=torch.float32
+        ),
+    )
+
+
+def _empty_block_fp8(size):
+    *dims, n, k = size
+    return (
+        torch.empty(size, device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (*dims, ceil_div(n, _BLOCK_SIZE), ceil_div(k, _BLOCK_SIZE)),
+            device="cuda",
+            dtype=torch.float32,
+        ),
+    )
+
+
+_BLOCK_SIZE = 128
+
+
+class _NormalWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.fp8_gemm_nt(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+        )
+
+
+class _GroupedContWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.m_indices = torch.zeros((max_m,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
+
+    def execute(self, m):
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
+            (self.lhs_q[:m], self.lhs_s[:m]),
+            (self.rhs_q, self.rhs_s),
+            self.out[:m],
+            m_indices=self.m_indices[:m],
+        )
+
+
+class _GroupedMaskedWarmupExecutor(_BaseWarmupExecutor):
+    def __init__(self, max_m: int, n: int, k: int, num_groups: int):
+        self.lhs_q, self.lhs_s = _empty_token_fp8((num_groups, max_m, k))
+        self.rhs_q, self.rhs_s = _empty_block_fp8((num_groups, n, k))
+        self.masked_m = torch.zeros((num_groups,), device="cuda", dtype=torch.int32)
+        self.out = torch.empty(
+            (num_groups, max_m, n), device="cuda", dtype=torch.bfloat16
+        )
+
+    def execute(self, m):
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
+            (self.lhs_q, self.lhs_s),
+            (self.rhs_q, self.rhs_s),
+            self.out,
+            masked_m=self.masked_m,
+            # DeepGEMM uses `expect_m` instead of input shape for `get_best_config`
+            expected_m=m,
+        )


 @contextmanager
 def deep_gemm_execution_hook(
     m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
 ):
-    # not supported yet
-    if not DEEPGEMM_BLACKWELL:
-        _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
-
-    with _log_jit_build(m, n, k, kernel_type):
-        yield
+    _maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
+    yield
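Taken together, this rewrite replaces per-config JIT compilation (via `RuntimeCache` patching and `thread_map`) with shape-level warmup: each executor allocates dummy FP8 tensors once and runs every M through the real DeepGEMM entry points. A sketch of driving the new path directly, using the module's own internals; the shapes are illustrative and not from the PR:

```python
# Pre-build masked grouped-GEMM kernels for one (n, k, num_groups) combination.
_compile_deep_gemm_one_type_all(
    kernel_type=DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED,
    n=4096,                  # illustrative output dim
    k=7168,                  # illustrative reduction dim
    num_groups=8,            # illustrative expert count
    m_list=_BUILTIN_M_LIST,  # all Ms, as _maybe_compile_deep_gemm_one_type_all passes
)
```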
@@ -24,14 +24,12 @@ def _compute_enable_deep_gemm():
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")


+def _is_blackwell_arch() -> bool:
+    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
+    return major == 10
+
+
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

-try:
-    from deep_gemm import fp8_gemm_nt
-
-    # They have not given a name to this breaking change
-    DEEPGEMM_BLACKWELL = True
-except ImportError:
-    DEEPGEMM_BLACKWELL = False
-
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
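`_is_blackwell_arch()` keys off the device's compute capability instead of probing DeepGEMM's import surface. A quick check of what it sees, using the standard PyTorch API:

```python
import torch

# SM 10.x devices (B200-class) report major == 10, which now selects the
# Blackwell DeepGEMM path; the old import-based detection is gone.
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
print(f"compute capability sm_{major}{minor}; blackwell={major == 10}")
```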
@@ -16,33 +16,16 @@ logger = logging.getLogger(__name__)

 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm

-    if DEEPGEMM_BLACKWELL:
-        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import (
-            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
-        from deep_gemm import (
-            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-    else:
-        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
-        from deep_gemm import get_col_major_tma_aligned_tensor
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
-        )
-        from deep_gemm import (
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
-        )
+    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor


 # TODO maybe rename these functions
 def grouped_gemm_nt_f8f8bf16_masked(
     lhs: Tuple[torch.Tensor, torch.Tensor],
     rhs: Tuple[torch.Tensor, torch.Tensor],
     out: torch.Tensor,
     masked_m: torch.Tensor,
     expected_m: int,
     recipe=None,
 ):
     num_groups, _, k = lhs[0].shape
     _, n, _ = rhs[0].shape
@@ -51,13 +34,12 @@ def grouped_gemm_nt_f8f8bf16_masked(
     with compile_utils.deep_gemm_execution_hook(
         expected_m, n, k, num_groups, kernel_type
     ):
-        _grouped_gemm_nt_f8f8bf16_masked_raw(
+        deep_gemm.fp8_m_grouped_gemm_nt_masked(
             lhs,
             rhs,
             out,
             masked_m,
             expected_m,
+            **({"recipe": recipe} if DEEPGEMM_BLACKWELL else {})
         )
@@ -72,7 +54,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
     kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG

     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
+        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices)


 def gemm_nt_f8f8bf16(
@@ -86,7 +68,7 @@ def gemm_nt_f8f8bf16(
     kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16

     with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
-        _gemm_nt_f8f8bf16_raw(
+        deep_gemm.fp8_gemm_nt(
             lhs,
             rhs,
             out,
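The wrappers above now call the renamed DeepGEMM API directly instead of the aliased `_raw` imports. The mapping, taken from the hunks in this file:

```python
# Old JIT-kernel entry points -> new DeepGEMM API:
#   gemm_fp8_fp8_bf16_nt                      -> deep_gemm.fp8_gemm_nt
#   m_grouped_gemm_fp8_fp8_bf16_nt_contiguous -> deep_gemm.m_grouped_fp8_gemm_nt_contiguous
#   m_grouped_gemm_fp8_fp8_bf16_nt_masked     -> deep_gemm.fp8_m_grouped_gemm_nt_masked
import deep_gemm

def bf16_gemm(lhs, rhs, out):
    # lhs/rhs are (fp8 data, float32 scales) tuples and out is bfloat16,
    # matching the warmup executors in compile_utils above.
    deep_gemm.fp8_gemm_nt(lhs, rhs, out)
```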
@@ -298,7 +298,7 @@ def _per_token_group_quant_8bit_raw(
     )

     if scale_ue8m0:
-        from deep_gemm.utils.layout import transform_sf_into_required_layout
+        from deep_gemm import transform_sf_into_required_layout

         assert group_size == 128
         x_s = transform_sf_into_required_layout(
@@ -338,7 +338,7 @@ def _per_token_group_quant_8bit_fuse_silu_and_mul(
     #     scale_ue8m0=scale_ue8m0,
     # )

-    from deep_gemm.utils.layout import transform_sf_into_required_layout
+    from deep_gemm import transform_sf_into_required_layout

     from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
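Both hunks move `transform_sf_into_required_layout` from `deep_gemm.utils.layout` to the package's top level. A hedged sketch for code that must tolerate both DeepGEMM layouts during the transition (both import paths appear in this diff):

```python
try:
    from deep_gemm import transform_sf_into_required_layout  # new location
except ImportError:
    from deep_gemm.utils.layout import transform_sf_into_required_layout  # old location
```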
@@ -459,7 +459,7 @@ def _requant_weight_ue8m0(
         import deep_gemm.utils.layout

         sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
-        sf = deep_gemm.utils.layout.get_col_major_tma_aligned_packed_tensor(sf)
+        sf = deep_gemm.utils.layout.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf)
         return sf

     out_s = _transform_scale(out_s, mn=out_w.shape[-2])
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional
+
 import torch


@@ -24,7 +26,7 @@ class MXFP4QuantizeUtil:
     E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

     @classmethod
-    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
+    def quantize(cls, input: torch.Tensor, block_size: Optional[int]) -> tuple:
         """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
         Args:
             input (torch.Tensor): The input tensor to be quantized.
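The `int | None` to `Optional[int]` swap is behavior-preserving: the two spell the same type, but `Optional` stays importable on interpreters and tooling that predate PEP 604 unions. A minimal illustration (hypothetical function, not the source):

```python
from typing import Optional

def quantize(data, block_size: Optional[int]) -> tuple:
    # Equivalent to `block_size: int | None` (PEP 604), but valid on older Pythons.
    ...
```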
@@ -50,25 +50,17 @@ FetchContent_Declare(
 )
 FetchContent_Populate(repo-cutlass)

-# DeepGEMM
-if("${CUDA_VERSION}" VERSION_EQUAL "12.8")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-elseif("${CUDA_VERSION}" VERSION_EQUAL "12.9")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-elseif("${CUDA_VERSION}" VERSION_EQUAL "13.0")
-  set(DeepGEMM_REPO "https://github.com/sgl-project/DeepGEMM")
-  set(DeepGEMM_TAG "blackwell")
-else()
-  set(DeepGEMM_REPO "https://github.com/deepseek-ai/DeepGEMM")
-  set(DeepGEMM_TAG "391755ada0ffefa9a6a52b6f14dcaf22d1a463e0")
-endif()
 FetchContent_Declare(
   repo-fmt
   GIT_REPOSITORY https://github.com/fmtlib/fmt
   GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
   GIT_SHALLOW OFF
 )

 FetchContent_Declare(
   repo-deepgemm
-  GIT_REPOSITORY ${DeepGEMM_REPO}
-  GIT_TAG ${DeepGEMM_TAG}
+  GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
+  GIT_TAG sgl
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-deepgemm)
@@ -86,7 +78,7 @@ FetchContent_Populate(repo-triton)
 FetchContent_Declare(
   repo-flashinfer
   GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
-  GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)
@@ -182,28 +174,11 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_100,code=sm_100"
     "-gencode=arch=compute_100a,code=sm_100a"
     "-gencode=arch=compute_103,code=sm_103"
     "-gencode=arch=compute_103a,code=sm_103a"
-    "-gencode=arch=compute_101,code=sm_101"
-    "-gencode=arch=compute_101a,code=sm_101a"
     "-gencode=arch=compute_120,code=sm_120"
     "-gencode=arch=compute_120a,code=sm_120a"
   )

-  # refer sm_121, sm_110 and sm_101 description https://github.com/pytorch/pytorch/pull/156176
-  if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
-    list(APPEND SGL_KERNEL_CUDA_FLAGS
-      "-gencode=arch=compute_110,code=sm_110"
-      "-gencode=arch=compute_110a,code=sm_110a"
-      "-gencode=arch=compute_121,code=sm_121"
-      "-gencode=arch=compute_121a,code=sm_121a"
-      "--compress-mode=size"
-    )
-  else()
-    list(APPEND SGL_KERNEL_CUDA_FLAGS
-      "-gencode=arch=compute_101,code=sm_101"
-      "-gencode=arch=compute_101a,code=sm_101a"
-    )
-  endif()

 else()
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-use_fast_math"
@@ -286,6 +261,12 @@ set(SOURCES
     "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
     "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
     "csrc/moe/marlin_moe_wna16/ops.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
+    "csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
     "csrc/moe/moe_align_kernel.cu"
     "csrc/moe/moe_fused_gate.cu"
    "csrc/moe/moe_topk_softmax_kernels.cu"
@@ -321,8 +302,6 @@ target_include_directories(common_ops PRIVATE
     ${repo-cutlass_SOURCE_DIR}/examples/common
     ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
 )
-set_source_files_properties("csrc/gemm/per_token_group_quant_8bit" PROPERTIES COMPILE_OPTIONS "--use_fast_math")
-

 find_package(Python3 COMPONENTS Interpreter REQUIRED)
 execute_process(
@@ -464,13 +443,38 @@ install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
 set(DEEPGEMM_SOURCES
     "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp"
 )
-# JIT Logic
+# DeepGEMM

-install(DIRECTORY "${repo-deepgemm_SOURCE_DIR}/deep_gemm/"
-        DESTINATION "deep_gemm"
-        PATTERN ".git*" EXCLUDE
-        PATTERN "__pycache__" EXCLUDE)
 Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES})

+# Link against necessary libraries, including nvrtc for JIT compilation.
+target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static)
+
+# Add include directories needed by DeepGEMM.
+target_include_directories(deep_gemm_cpp PRIVATE
+    ${repo-deepgemm_SOURCE_DIR}/deep_gemm/include
+    ${repo-cutlass_SOURCE_DIR}/include
+    ${repo-fmt_SOURCE_DIR}/include
+)
+
+# Apply the same compile options as common_ops.
+target_compile_options(deep_gemm_cpp PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+
+# Create an empty __init__.py to make `deepgemm` a Python package.
+file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "")
+install(
+    FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py
+    DESTINATION deep_gemm
+    RENAME __init__.py
+)
+
+# Install the compiled DeepGEMM API library.
+install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm)
+
+# Install the source files required by DeepGEMM for runtime JIT compilation.
+install(
+    DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/
+    DESTINATION deep_gemm
+)

 install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/"
         DESTINATION "deep_gemm/include/cute")
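The install rules above turn `deep_gemm` into a self-contained wheel package: a generated `__init__.py`, the compiled `deep_gemm_cpp` extension, the Python/JIT sources, and the CUTLASS `cute` headers it JIT-compiles against. A hedged sanity check of the assumed layout (adjust if the built wheel differs):

```python
import pathlib

import deep_gemm

pkg = pathlib.Path(deep_gemm.__file__).parent
print(sorted(p.name for p in pkg.iterdir()))
# expected per the CMake rules: __init__.py, deep_gemm_cpp*.so, include/ (with cute/),
# plus the deep_gemm Python sources used for runtime JIT compilation
```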
@@ -9,7 +9,6 @@ import jinja2
 FILE_HEAD = """
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"
@@ -34,17 +33,6 @@ TEMPLATE = (
     "( MARLIN_KERNEL_PARAMS );"
 )

-KERNEL_FILE_TEMPLATE = (
-    "// auto generated by generate.py\n"
-    "// clang-format off\n"
-    "#pragma once\n\n"
-    "{% for kernel_file in kernel_files %}"
-    '#include "{{ kernel_file }}"\n'
-    "{% endfor %}"
-)
-
-KERNEL_FILE_NAME = "kernel_marlin.cuh"
-
 # int8 with zero point case (sglang::kU8) is also supported,
 # we don't add it to reduce wheel size.
 SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
@@ -60,12 +48,11 @@ DTYPES = ["fp16", "bf16"]

 def remove_old_kernels():
-    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cuh"):
+    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
         subprocess.call(["rm", "-f", filename])


 def generate_new_kernels():
-    kernel_files = set()
     for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
         has_zp = "B" not in scalar_type
         all_template_str_list = []
@@ -108,20 +95,10 @@ def generate_new_kernels():

         file_content = FILE_HEAD + "\n\n"
         file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
-        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
+        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"

         with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
             f.write(file_content)
-        kernel_files.add(filename)
-
-    kernel_files = list(kernel_files)
-    kernel_files.sort()
-
-    file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
-        kernel_files=kernel_files
-    )
-    with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
-        f.write(file_content)


 if __name__ == "__main__":
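With the aggregate `kernel_marlin.cuh` header gone, each rendered Marlin kernel becomes its own `.cu` translation unit, compiled directly via the new entries in CMake `SOURCES`. The emitted names follow directly from the loops above:

```python
import itertools

SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
DTYPES = ["fp16", "bf16"]

# Mirrors f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" in generate_new_kernels().
names = [
    f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES)
]
print(names)  # the six kernel_*.cu files added to CMake SOURCES above
```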
@@ -1,4 +1,3 @@
 #pragma once
-
 #ifndef MARLIN_NAMESPACE_NAME
 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"

@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"

@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"

@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"

@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"

@@ -1,6 +1,5 @@
 // auto generated by generate.py
 // clang-format off
-#pragma once

 #include "kernel.h"
 #include "marlin_template.h"
@@ -1,10 +0,0 @@
-// auto generated by generate.py
-// clang-format off
-#pragma once
-
-#include "kernel_bf16_ku4.cuh"
-#include "kernel_bf16_ku4b8.cuh"
-#include "kernel_bf16_ku8b128.cuh"
-#include "kernel_fp16_ku4.cuh"
-#include "kernel_fp16_ku4b8.cuh"
-#include "kernel_fp16_ku8b128.cuh"
@@ -18,8 +18,6 @@
 /*
  * Adapted from https://github.com/IST-DASLab/marlin
  */
-#pragma once
-
 #ifndef MARLIN_NAMESPACE_NAME
 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
 #endif
@@ -24,7 +24,6 @@
 #endif

 #include "kernel.h"
-#include "kernel_marlin.cuh"

 #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
     static_assert(                                \
@@ -23,7 +23,6 @@ limitations under the License.
 #ifndef USE_ROCM
 #include <cub/cub.cuh>
 #include <cub/util_type.cuh>
-#include <cuda/functional>
 #else
 #include <hipcub/hipcub.hpp>
 #include <hipcub/util_type.hpp>
@@ -34,16 +33,6 @@ limitations under the License.
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))

-// Define reduction operators based on CUDA version
-// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
-#if CUDA_VERSION >= 12090
-using MaxReduceOp = cuda::maximum<>;
-using MinReduceOp = cuda::minimum<>;
-#else
-using MaxReduceOp = cub::Max;
-using MinReduceOp = cub::Min;
-#endif
-
 /// Aligned array type
 template <
     typename T,
@@ -83,6 +72,7 @@ __launch_bounds__(TPB) __global__

     const int thread_row_offset = blockIdx.x * num_cols;

+    cub::Sum sum;
     float threadData(-FLT_MAX);

     // Don't touch finished rows.
@@ -95,7 +85,7 @@ __launch_bounds__(TPB) __global__
         threadData = max(convert_to_float<T>(input[idx]), threadData);
     }

-    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
+    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());

     if (threadIdx.x == 0) {
         float_max = maxElem;
@@ -109,7 +99,7 @@ __launch_bounds__(TPB) __global__
         threadData += exp((convert_to_float<T>(input[idx]) - float_max));
     }

-    const auto Z = BlockReduce(tmpStorage).Sum(threadData);
+    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

     if (threadIdx.x == 0) {
         normalizing_factor = 1.f / Z;
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"

@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"

 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"

@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "sgl-kernel"
-version = "0.3.6.post2"
+version = "0.3.7"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -1 +1 @@
-__version__ = "0.3.6.post2"
+__version__ = "0.3.7"