[feature] enable pre compile jit deep_gemm (#5580)
This commit is contained in:
136
python/sglang/compile_deep_gemm.py
Normal file
136
python/sglang/compile_deep_gemm.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
"""
|
||||||
|
Compile DeepGEMM Kernels for a model with specify server arguments
|
||||||
|
|
||||||
|
This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
|
||||||
|
It accepts server arguments (the same as launch_server.py).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.srt.warmup import warmup
|
||||||
|
|
||||||
|
multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
# Reduce warning
|
||||||
|
os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CompileArgs:
|
||||||
|
timeout: int = 3600
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
# use the default value's type to cast the args into correct types.
|
||||||
|
attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
|
||||||
|
return cls(
|
||||||
|
**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@warmup("compile-deep-gemm")
|
||||||
|
async def warm_up_compile(tokenizer_manager: TokenizerManager):
|
||||||
|
print("\nGenerate warm up request for compiling DeepGEMM...\n")
|
||||||
|
generate_req_input = GenerateReqInput(
|
||||||
|
input_ids=[0, 1, 2, 3],
|
||||||
|
sampling_params={
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_server_internal(server_args):
|
||||||
|
try:
|
||||||
|
launch_server(server_args)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
kill_process_tree(os.getpid(), include_parent=False)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_server_process_and_send_one_request(
|
||||||
|
server_args: ServerArgs, compile_args: CompileArgs
|
||||||
|
):
|
||||||
|
proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
|
||||||
|
proc.start()
|
||||||
|
base_url = f"http://{server_args.host}:{server_args.port}"
|
||||||
|
timeout = compile_args.timeout
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json; charset=utf-8",
|
||||||
|
}
|
||||||
|
response = requests.get(f"{base_url}/v1/models", headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return proc
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(10)
|
||||||
|
raise TimeoutError(
|
||||||
|
"DeepGEMM Kernels compilation timeout."
|
||||||
|
"\n\nFeel free and please restart the command."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
|
||||||
|
# Disbale cuda graph and torch compile to save time
|
||||||
|
server_args.disable_cuda_graph = True
|
||||||
|
server_args.enable_torch_compile = False
|
||||||
|
print(f"Disable CUDA Graph and Torch Compile to save time...")
|
||||||
|
|
||||||
|
# Set watchdog timeout to compile_args.timeout because compilation will take a long time
|
||||||
|
server_args.watchdog_timeout = compile_args.timeout
|
||||||
|
server_args.warmups = "compile-deep-gemm"
|
||||||
|
|
||||||
|
|
||||||
|
def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
|
||||||
|
print(
|
||||||
|
"Begin DeepGEMM Kernels compilation...\n"
|
||||||
|
"It may take a long time and timeout maybe raised "
|
||||||
|
"while the compilation is still in progress.\n"
|
||||||
|
"Just feel free to restart the command "
|
||||||
|
"until the compilation is fully finished.\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
proc = launch_server_process_and_send_one_request(server_args, compile_args)
|
||||||
|
|
||||||
|
kill_process_tree(proc.pid)
|
||||||
|
|
||||||
|
print("\nDeepGEMM Kernels compilation finished successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
ServerArgs.add_cli_args(parser)
|
||||||
|
CompileArgs.add_cli_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
server_args = ServerArgs.from_cli_args(args)
|
||||||
|
compile_args = CompileArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
refine_server_args(server_args, compile_args)
|
||||||
|
|
||||||
|
run_compile(server_args, compile_args)
|
||||||
378
python/sglang/srt/layers/quantization/deep_gemm.py
Normal file
378
python/sglang/srt/layers/quantization/deep_gemm.py
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import IntEnum, auto
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
|
||||||
|
|
||||||
|
_ENABLE_JIT_DEEPGEMM = False
|
||||||
|
if is_cuda():
|
||||||
|
import deep_gemm
|
||||||
|
from deep_gemm import get_num_sms
|
||||||
|
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||||
|
from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
|
||||||
|
from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
|
||||||
|
from deep_gemm.jit_kernels.m_grouped_gemm import (
|
||||||
|
template as deep_gemm_grouped_gemm_template,
|
||||||
|
)
|
||||||
|
from deep_gemm.jit_kernels.tuner import jit_tuner
|
||||||
|
|
||||||
|
sm_version = get_device_sm()
|
||||||
|
if sm_version == 90:
|
||||||
|
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
|
||||||
|
_ENABLE_JIT_DEEPGEMM = True
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||||
|
_ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
|
||||||
|
"SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
|
||||||
|
)
|
||||||
|
_DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
|
||||||
|
_COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
|
||||||
|
_IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
|
||||||
|
|
||||||
|
# Force redirect deep_gemm cache_dir
|
||||||
|
os.environ["DG_CACHE_DIR"] = os.getenv(
|
||||||
|
"SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||||
|
global _BUILTIN_M_LIST
|
||||||
|
global _DO_COMPILE
|
||||||
|
|
||||||
|
# Generate m_max
|
||||||
|
m_max = 1024 * 16
|
||||||
|
if server_args.chunked_prefill_size < 1:
|
||||||
|
m_max = 1024 * 64
|
||||||
|
elif server_args.chunked_prefill_size > 8192:
|
||||||
|
m_max = server_args.chunked_prefill_size * 2
|
||||||
|
m_max = min(1024 * 128, m_max)
|
||||||
|
_BUILTIN_M_LIST = list(range(1, m_max + 1))
|
||||||
|
|
||||||
|
# Check if is the first rank on node
|
||||||
|
_DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
|
||||||
|
|
||||||
|
|
||||||
|
class DeepGemmKernelType(IntEnum):
|
||||||
|
GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
|
||||||
|
GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
|
||||||
|
GEMM_NT_F8F8BF16 = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepGemmKernelHelper:
|
||||||
|
name: str
|
||||||
|
compile_func: Callable[
|
||||||
|
[
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
int,
|
||||||
|
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
configure_func: Callable[
|
||||||
|
[int, int, int, int, int],
|
||||||
|
Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_warning_1():
|
||||||
|
if not _IN_PRE_COMPILE_STAGE:
|
||||||
|
logger.warning(
|
||||||
|
"Entering DeepGEMM JIT Pre-Complie session. "
|
||||||
|
"And it may takes a long time(Typically 10-20 mins) "
|
||||||
|
"if you have not run `sglang.compile_deep_gemm`. "
|
||||||
|
"Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
|
||||||
|
" for pre-compilation to reduce the overhead if you have not run it before. "
|
||||||
|
"For example: "
|
||||||
|
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_warning_2():
|
||||||
|
logger.warning(
|
||||||
|
"Entering DeepGEMM JIT Single Kernel Complie session. "
|
||||||
|
"And it will makes inference throughput becomes flaky. "
|
||||||
|
"Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
|
||||||
|
" for pre-compilation to solve this issue. "
|
||||||
|
"For example: "
|
||||||
|
"`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
num_groups: int,
|
||||||
|
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||||
|
) -> None:
|
||||||
|
# Auto-tuning with compilation
|
||||||
|
global deep_gemm_includes, deep_gemm_grouped_gemm_template
|
||||||
|
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||||
|
_ = jit_tuner.compile_and_tune(
|
||||||
|
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||||
|
keys={
|
||||||
|
"N": n,
|
||||||
|
"K": k,
|
||||||
|
"BLOCK_M": block_m,
|
||||||
|
"BLOCK_N": block_n,
|
||||||
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
|
"NUM_GROUPS": num_groups,
|
||||||
|
"NUM_STAGES": num_stages,
|
||||||
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
|
"GEMM_TYPE": "GroupedMasked",
|
||||||
|
},
|
||||||
|
space=(),
|
||||||
|
includes=deep_gemm_includes,
|
||||||
|
arg_defs=(
|
||||||
|
("lhs", torch.float8_e4m3fn),
|
||||||
|
("lhs_scales", torch.float),
|
||||||
|
("rhs", torch.float8_e4m3fn),
|
||||||
|
("rhs_scales", torch.float),
|
||||||
|
("out", torch.bfloat16),
|
||||||
|
("grouped_layout", torch.int32),
|
||||||
|
("m", int),
|
||||||
|
("stream", torch.cuda.Stream),
|
||||||
|
("num_sms", int),
|
||||||
|
("smem_size", int),
|
||||||
|
),
|
||||||
|
template=deep_gemm_grouped_gemm_template,
|
||||||
|
args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
num_groups: int,
|
||||||
|
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||||
|
) -> None:
|
||||||
|
global deep_gemm_includes, deep_gemm_grouped_gemm_template
|
||||||
|
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||||
|
_ = jit_tuner.compile_and_tune(
|
||||||
|
name="m_grouped_gemm_fp8_fp8_bf16_nt",
|
||||||
|
keys={
|
||||||
|
"N": n,
|
||||||
|
"K": k,
|
||||||
|
"BLOCK_M": block_m,
|
||||||
|
"BLOCK_N": block_n,
|
||||||
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
|
"NUM_GROUPS": num_groups,
|
||||||
|
"NUM_STAGES": num_stages,
|
||||||
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
|
"GEMM_TYPE": "GroupedContiguous",
|
||||||
|
},
|
||||||
|
space=(),
|
||||||
|
includes=deep_gemm_includes,
|
||||||
|
arg_defs=(
|
||||||
|
("lhs", torch.float8_e4m3fn),
|
||||||
|
("lhs_scales", torch.float),
|
||||||
|
("rhs", torch.float8_e4m3fn),
|
||||||
|
("rhs_scales", torch.float),
|
||||||
|
("out", torch.bfloat16),
|
||||||
|
("grouped_layout", torch.int32),
|
||||||
|
("m", int),
|
||||||
|
("num_groups", int),
|
||||||
|
("stream", torch.cuda.Stream),
|
||||||
|
("num_sms", int),
|
||||||
|
("smem_size", int),
|
||||||
|
),
|
||||||
|
template=deep_gemm_grouped_gemm_template,
|
||||||
|
args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compile_gemm_nt_f8f8bf16_one(
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
_: int, # _ is a dummy parameter to align with other interfaces
|
||||||
|
config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
|
||||||
|
) -> None:
|
||||||
|
global deep_gemm_includes, deep_gemm_gemm_template
|
||||||
|
_, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
|
||||||
|
_ = jit_tuner.compile_and_tune(
|
||||||
|
name="gemm_fp8_fp8_bf16_nt",
|
||||||
|
keys={
|
||||||
|
"N": n,
|
||||||
|
"K": k,
|
||||||
|
"BLOCK_M": block_m,
|
||||||
|
"BLOCK_N": block_n,
|
||||||
|
"SWIZZLE_D_MODE": smem_config[1],
|
||||||
|
"BLOCK_N_PADDING": smem_config[2],
|
||||||
|
"NUM_STAGES": num_stages,
|
||||||
|
"NUM_TMA_MULTICAST": tma_multicast_config[0],
|
||||||
|
"IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
|
||||||
|
},
|
||||||
|
space=(),
|
||||||
|
includes=deep_gemm_includes,
|
||||||
|
arg_defs=(
|
||||||
|
("lhs", torch.float8_e4m3fn),
|
||||||
|
("lhs_scales", torch.float),
|
||||||
|
("rhs", torch.float8_e4m3fn),
|
||||||
|
("rhs_scales", torch.float),
|
||||||
|
("out", torch.bfloat16),
|
||||||
|
("m", int),
|
||||||
|
("stream", torch.cuda.Stream),
|
||||||
|
("num_sms", int),
|
||||||
|
("smem_size", int),
|
||||||
|
),
|
||||||
|
template=deep_gemm_gemm_template,
|
||||||
|
args=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
||||||
|
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
|
||||||
|
name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||||
|
compile_func=_compile_grouped_gemm_nt_f8f8bf16_masked_one,
|
||||||
|
configure_func=lambda m, n, k, num_groups, num_sms: get_best_configs(
|
||||||
|
m, n, k, num_groups, num_sms, is_grouped_masked=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG: DeepGemmKernelHelper(
|
||||||
|
name="m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
|
||||||
|
compile_func=_compile_grouped_gemm_nt_f8f8bf16_contig_one,
|
||||||
|
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
|
||||||
|
m, n, k, 1, num_sms, is_grouped_contiguous=True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
DeepGemmKernelType.GEMM_NT_F8F8BF16: DeepGemmKernelHelper(
|
||||||
|
name="gemm_fp8_fp8_bf16_nt",
|
||||||
|
compile_func=_compile_gemm_nt_f8f8bf16_one,
|
||||||
|
configure_func=lambda m, n, k, _, num_sms: get_best_configs(
|
||||||
|
m, n, k, 1, num_sms
|
||||||
|
),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_compile_deep_gemm_one_type_all(
|
||||||
|
kernel_type: DeepGemmKernelType,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
num_groups: int,
|
||||||
|
m_list: Optional[List[int]] = None,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
global _INITIALIZATION_DICT
|
||||||
|
global _BUILTIN_M_LIST
|
||||||
|
|
||||||
|
query_key = (kernel_type, n, k, num_groups)
|
||||||
|
if (
|
||||||
|
_ENABLE_JIT_DEEPGEMM_PRECOMPILE
|
||||||
|
and _DO_COMPILE
|
||||||
|
and _INITIALIZATION_DICT.get(query_key) is None
|
||||||
|
):
|
||||||
|
_INITIALIZATION_DICT[query_key] = True
|
||||||
|
|
||||||
|
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
|
||||||
|
_compile_warning_1()
|
||||||
|
logger.info(
|
||||||
|
f"Try DeepGEMM JIT Compiling for "
|
||||||
|
f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
|
||||||
|
f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
|
||||||
|
num_sms = get_num_sms()
|
||||||
|
collected_configs = set()
|
||||||
|
for m in m_list if m_list is not None else _BUILTIN_M_LIST:
|
||||||
|
# Put config into set to get unique configs and reduce cases to be compiled
|
||||||
|
collected_configs.add(
|
||||||
|
kernel_helper.configure_func(m, n, k, num_groups, num_sms)
|
||||||
|
)
|
||||||
|
compile_func = lambda config: kernel_helper.compile_func(
|
||||||
|
n, k, num_groups, config
|
||||||
|
)
|
||||||
|
thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
|
||||||
|
|
||||||
|
|
||||||
|
def grouped_gemm_nt_f8f8bf16_masked(
|
||||||
|
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
out: torch.Tensor,
|
||||||
|
masked_m: torch.Tensor,
|
||||||
|
expected_m: int,
|
||||||
|
):
|
||||||
|
num_groups, _, k = lhs[0].shape
|
||||||
|
_, n, _ = rhs[0].shape
|
||||||
|
|
||||||
|
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
||||||
|
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||||
|
|
||||||
|
with _log_jit_build(expected_m, n, k, kernel_type):
|
||||||
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||||
|
lhs, rhs, out, masked_m, expected_m
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def grouped_gemm_nt_f8f8bf16_contig(
|
||||||
|
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
out: torch.Tensor,
|
||||||
|
m_indices: torch.Tensor,
|
||||||
|
):
|
||||||
|
m, k = lhs[0].shape
|
||||||
|
num_groups, n, _ = rhs[0].shape
|
||||||
|
|
||||||
|
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||||
|
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||||
|
|
||||||
|
with _log_jit_build(m, n, k, kernel_type):
|
||||||
|
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_nt_f8f8bf16(
|
||||||
|
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
out: torch.Tensor,
|
||||||
|
):
|
||||||
|
m, k = lhs[0].shape
|
||||||
|
n, _ = rhs[0].shape
|
||||||
|
|
||||||
|
kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||||
|
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
|
||||||
|
|
||||||
|
with _log_jit_build(m, n, k, kernel_type):
|
||||||
|
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||||
|
if _IN_PRE_COMPILE_STAGE:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
from deep_gemm.jit.runtime import RuntimeCache
|
||||||
|
|
||||||
|
origin_func = RuntimeCache.__getitem__
|
||||||
|
|
||||||
|
def __patched_func(self, *args, **kwargs):
|
||||||
|
ret = origin_func(self, *args, **kwargs)
|
||||||
|
if ret is None:
|
||||||
|
kernel_helper = _KERNEL_HELPER_DICT[kernel_type]
|
||||||
|
_compile_warning_2()
|
||||||
|
logger.warning(
|
||||||
|
f"DeepGEMM JIT Compiling for <{kernel_helper.name}> M={M}, N={N}, K={K}. Please wait."
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
RuntimeCache.__getitem__ = __patched_func
|
||||||
|
yield
|
||||||
|
RuntimeCache.__getitem__ = origin_func
|
||||||
@@ -16,19 +16,17 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
get_bool_env_var,
|
|
||||||
get_device_core_count,
|
get_device_core_count,
|
||||||
get_device_name,
|
get_device_name,
|
||||||
get_device_sm,
|
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
supports_custom_op,
|
supports_custom_op,
|
||||||
@@ -43,22 +41,16 @@ else:
|
|||||||
fp8_max = torch.finfo(_fp8_type).max
|
fp8_max = torch.finfo(_fp8_type).max
|
||||||
fp8_min = -fp8_max
|
fp8_min = -fp8_max
|
||||||
|
|
||||||
_enable_jit_deepgemm = False
|
|
||||||
_enable_jit_deepgemm_bmm = False
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
import deep_gemm
|
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
sgl_per_tensor_quant_fp8,
|
sgl_per_tensor_quant_fp8,
|
||||||
sgl_per_token_group_quant_fp8,
|
sgl_per_token_group_quant_fp8,
|
||||||
sgl_per_token_quant_fp8,
|
sgl_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
|
|
||||||
sm_version = get_device_sm()
|
from sglang.srt.layers.quantization.deep_gemm import (
|
||||||
if sm_version == 90:
|
gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
|
||||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
|
)
|
||||||
_enable_jit_deepgemm = True
|
|
||||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
|
|
||||||
_enable_jit_deepgemm_bmm = True
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,10 +63,7 @@ if supports_custom_op():
|
|||||||
Bs: torch.Tensor,
|
Bs: torch.Tensor,
|
||||||
C: torch.Tensor,
|
C: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
M, K = A.shape
|
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||||
N, _ = B.shape
|
|
||||||
with _log_jit_build(M, N, K):
|
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
|
||||||
|
|
||||||
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _log_jit_build(M: int, N: int, K: int):
|
|
||||||
from deep_gemm.jit.runtime import RuntimeCache
|
|
||||||
|
|
||||||
origin_func = RuntimeCache.__getitem__
|
|
||||||
|
|
||||||
def __patched_func(self, *args, **kwargs):
|
|
||||||
ret = origin_func(self, *args, **kwargs)
|
|
||||||
if ret is None:
|
|
||||||
logger.warning(
|
|
||||||
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
|
|
||||||
)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
RuntimeCache.__getitem__ = __patched_func
|
|
||||||
yield
|
|
||||||
RuntimeCache.__getitem__ = origin_func
|
|
||||||
|
|
||||||
|
|
||||||
def w8a8_block_fp8_matmul(
|
def w8a8_block_fp8_matmul(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# deepgemm only support bf16
|
# deepgemm only support bf16
|
||||||
if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
|
if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||||
else:
|
else:
|
||||||
with _log_jit_build(M, N, K):
|
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
|
|
||||||
else:
|
else:
|
||||||
kernel = (
|
kernel = (
|
||||||
_w8a8_block_fp8_matmul_unrolledx4
|
_w8a8_block_fp8_matmul_unrolledx4
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
VLLM_AVAILABLE = False
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
_enable_jit_deepgemm,
|
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
scaled_fp8_quant,
|
scaled_fp8_quant,
|
||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
)
|
)
|
||||||
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
|
||||||
else:
|
else:
|
||||||
if _enable_jit_deepgemm:
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
q_input, x_scale = sglang_per_token_group_quant_fp8(
|
||||||
input_2d,
|
input_2d,
|
||||||
block_size[1],
|
block_size[1],
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
|
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import (
|
||||||
|
_ENABLE_JIT_DEEPGEMM,
|
||||||
|
update_deep_gemm_config,
|
||||||
|
)
|
||||||
from sglang.srt.layers.sampler import Sampler
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||||
from sglang.srt.lora.lora_manager import LoRAManager
|
from sglang.srt.lora.lora_manager import LoRAManager
|
||||||
@@ -169,6 +173,10 @@ class ModelRunner:
|
|||||||
# Get memory before model loading
|
# Get memory before model loading
|
||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
|
|
||||||
|
# Update deep gemm configure
|
||||||
|
if _ENABLE_JIT_DEEPGEMM:
|
||||||
|
update_deep_gemm_config(gpu_id, server_args)
|
||||||
|
|
||||||
# If it is a draft model tp_group can be different.
|
# If it is a draft model tp_group can be different.
|
||||||
self.initialize(min_per_gpu_memory)
|
self.initialize(min_per_gpu_memory)
|
||||||
|
|
||||||
|
|||||||
@@ -57,8 +57,8 @@ from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
|||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
_enable_jit_deepgemm_bmm,
|
|
||||||
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
per_tensor_quant_mla_deep_gemm_masked_fp8,
|
||||||
per_tensor_quant_mla_fp8,
|
per_tensor_quant_mla_fp8,
|
||||||
)
|
)
|
||||||
@@ -86,8 +86,11 @@ _is_hip = is_hip()
|
|||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
|
||||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.deep_gemm import (
|
||||||
|
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
from vllm._custom_ops import awq_dequantize
|
from vllm._custom_ops import awq_dequantize
|
||||||
|
|
||||||
@@ -702,7 +705,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_nope_out = q_nope.new_empty(
|
q_nope_out = q_nope.new_empty(
|
||||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||||
)
|
)
|
||||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||||
(q_nope_val, q_nope_scale),
|
(q_nope_val, q_nope_scale),
|
||||||
(self.w_kc, self.w_scale_k),
|
(self.w_kc, self.w_scale_k),
|
||||||
q_nope_out,
|
q_nope_out,
|
||||||
@@ -751,7 +754,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attn_bmm_output = attn_output.new_empty(
|
attn_bmm_output = attn_output.new_empty(
|
||||||
(self.num_local_heads, aligned_m, self.v_head_dim)
|
(self.num_local_heads, aligned_m, self.v_head_dim)
|
||||||
)
|
)
|
||||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||||
(attn_output_val, attn_output_scale),
|
(attn_output_val, attn_output_scale),
|
||||||
(self.w_vc, self.w_scale_v),
|
(self.w_vc, self.w_scale_v),
|
||||||
attn_bmm_output,
|
attn_bmm_output,
|
||||||
@@ -1520,7 +1523,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
_is_cuda
|
_is_cuda
|
||||||
and _enable_jit_deepgemm_bmm
|
and _ENABLE_JIT_DEEPGEMM
|
||||||
and weight_block_size[0] == 128
|
and weight_block_size[0] == 128
|
||||||
and weight_block_size[1] == 128
|
and weight_block_size[1] == 128
|
||||||
and model_dtype == torch.bfloat16
|
and model_dtype == torch.bfloat16
|
||||||
|
|||||||
@@ -98,6 +98,16 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
|
|||||||
return value in truthy_values
|
return value in truthy_values
|
||||||
|
|
||||||
|
|
||||||
|
def get_int_env_var(name: str, default: int = 0) -> int:
|
||||||
|
value = os.getenv(name)
|
||||||
|
if value is None or not value.strip():
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
||||||
def is_hip() -> bool:
|
def is_hip() -> bool:
|
||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user