Refactor DeepGEMM integration (#7150)
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from sglang.math_utils import ceil_div
|
||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||
from sglang.srt.utils import dispose_tensor, is_cuda
|
||||
|
||||
@@ -15,11 +16,6 @@ if _is_cuda:
|
||||
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
try:
|
||||
from deep_gemm import ceil_div
|
||||
except ImportError:
|
||||
logger.error(f"Failed to import ceil_div from deep_gemm.")
|
||||
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,11 @@
|
||||
import logging
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from sgl_kernel import silu_and_mul
|
||||
from torch.nn import Module
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
try:
|
||||
from deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||
)
|
||||
from sgl_kernel import silu_and_mul
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
|
||||
use_deep_gemm = True
|
||||
except ImportError:
|
||||
use_deep_gemm = False
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
scaled_fp8_quant,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
sglang_per_token_quant_fp8,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
|
||||
from sglang.srt.utils import (
|
||||
DeepEPMode,
|
||||
dispose_tensor,
|
||||
get_bool_env_var,
|
||||
is_hip,
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
|
||||
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
self.deepep_mode = deepep_mode
|
||||
if self.deepep_mode.enable_low_latency():
|
||||
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
assert (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||
self.w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
|
||||
):
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
return self.forward_deepgemm_contiguous(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
)
|
||||
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
|
||||
)
|
||||
del input_tensor
|
||||
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
del down_input
|
||||
down_input_scale = tma_align_input_scale(down_input_scale)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||
(down_input_fp8, down_input_scale),
|
||||
self.w2_weight_fp8,
|
||||
down_output,
|
||||
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
|
||||
gateup_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
hidden_states_fp8,
|
||||
self.w13_weight_fp8,
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8[0])
|
||||
|
||||
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
|
||||
n = self.w2_weight.size(1)
|
||||
down_input_fp8 = (
|
||||
down_input,
|
||||
get_col_major_tma_aligned_tensor(down_input_scale),
|
||||
deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
|
||||
)
|
||||
down_output = torch.empty(
|
||||
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
|
||||
)
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
down_input_fp8,
|
||||
self.w2_weight_fp8,
|
||||
down_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
|
||||
)
|
||||
|
||||
return down_output
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
@@ -236,14 +236,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
# TODO hard code 128 block quant,use fp8 communication
|
||||
hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return hidden_states, topk_idx, topk_weights, previous_event
|
||||
|
||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
@@ -345,7 +345,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
previous_event=previous_event,
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
|
||||
config=DeepEPConfig.get_instance().normal_dispatch_config,
|
||||
)
|
||||
|
||||
@@ -409,7 +409,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
output = hidden_states
|
||||
else:
|
||||
if hidden_states.shape[0] > 0:
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from .entrypoint import *
|
||||
@@ -5,33 +5,24 @@ from dataclasses import dataclass
|
||||
from enum import IntEnum, auto
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
DEEPGEMM_V202506,
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda
|
||||
from sglang.srt.utils import get_bool_env_var, get_int_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_ENABLE_JIT_DEEPGEMM = False
|
||||
|
||||
try:
|
||||
import deep_gemm
|
||||
from deep_gemm import get_num_sms
|
||||
from deep_gemm.jit import build
|
||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||
from deep_gemm.jit_kernels.gemm import get_best_configs
|
||||
from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version == 90:
|
||||
if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
|
||||
_ENABLE_JIT_DEEPGEMM = True
|
||||
except ImportError:
|
||||
logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")
|
||||
|
||||
|
||||
def get_enable_jit_deepgemm():
|
||||
return _ENABLE_JIT_DEEPGEMM
|
||||
pass
|
||||
|
||||
|
||||
_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
|
||||
@@ -52,8 +43,10 @@ os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
|
||||
# NVRTC may have performance loss with some cases.
|
||||
# And NVCC JIT speed is also 9x faster in the ref commit
|
||||
_USE_NVRTC_DEFAULT = "0"
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
try:
|
||||
from deep_gemm.jit.compiler import get_nvcc_compiler
|
||||
|
||||
get_nvcc_compiler()
|
||||
except:
|
||||
logger.warning(
|
||||
@@ -114,6 +107,7 @@ class DeepGemmKernelHelper:
|
||||
_INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dict()
|
||||
|
||||
|
||||
# TODO improve naming
|
||||
def _compile_warning_1():
|
||||
if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
|
||||
logger.warning(
|
||||
@@ -127,6 +121,7 @@ def _compile_warning_1():
|
||||
)
|
||||
|
||||
|
||||
# TODO improve naming
|
||||
def _compile_warning_2():
|
||||
logger.warning(
|
||||
"Entering DeepGEMM JIT Single Kernel Compile session. "
|
||||
@@ -238,6 +233,7 @@ def _compile_gemm_nt_f8f8bf16_one(
|
||||
_ = build("gemm_fp8_fp8_bf16_nt", code, FP8GemmRuntime, kwargs)
|
||||
|
||||
|
||||
# TODO further refactor warmup-related
|
||||
_KERNEL_HELPER_DICT: Dict[DeepGemmKernelType, DeepGemmKernelHelper] = {
|
||||
DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED: DeepGemmKernelHelper(
|
||||
name="m_grouped_gemm_fp8_fp8_bf16_nt_masked",
|
||||
@@ -270,7 +266,6 @@ def _maybe_compile_deep_gemm_one_type_all(
|
||||
num_groups: int,
|
||||
m_list: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
|
||||
global _INITIALIZATION_DICT
|
||||
global _BUILTIN_M_LIST
|
||||
|
||||
@@ -304,56 +299,6 @@ def _maybe_compile_deep_gemm_one_type_all(
|
||||
thread_map(compile_func, collected_configs, max_workers=_COMPILE_WORKERS)
|
||||
|
||||
|
||||
def grouped_gemm_nt_f8f8bf16_masked(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
):
|
||||
num_groups, _, k = lhs[0].shape
|
||||
_, n, _ = rhs[0].shape
|
||||
|
||||
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||
|
||||
with _log_jit_build(expected_m, n, k, kernel_type):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||
lhs, rhs, out, masked_m, expected_m
|
||||
)
|
||||
|
||||
|
||||
def grouped_gemm_nt_f8f8bf16_contig(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
m_indices: torch.Tensor,
|
||||
):
|
||||
m, k = lhs[0].shape
|
||||
num_groups, n, _ = rhs[0].shape
|
||||
|
||||
kernel_type = DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||
|
||||
with _log_jit_build(m, n, k, kernel_type):
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
|
||||
|
||||
|
||||
def gemm_nt_f8f8bf16(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
m, k = lhs[0].shape
|
||||
n, _ = rhs[0].shape
|
||||
|
||||
kernel_type = DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, 1)
|
||||
|
||||
with _log_jit_build(m, n, k, kernel_type):
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt(lhs, rhs, out)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||
if _IN_PRECOMPILE_STAGE:
|
||||
@@ -380,13 +325,14 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def configure_deep_gemm_num_sms(num_sms):
|
||||
if num_sms is None:
|
||||
def deep_gemm_execution_hook(
|
||||
m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
|
||||
):
|
||||
# not supported yet
|
||||
if DEEPGEMM_V202506:
|
||||
yield
|
||||
return
|
||||
|
||||
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
|
||||
with _log_jit_build(m, n, k, kernel_type):
|
||||
yield
|
||||
else:
|
||||
original_num_sms = deep_gemm.get_num_sms()
|
||||
deep_gemm.set_num_sms(num_sms)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deep_gemm.set_num_sms(original_num_sms)
|
||||
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_device_sm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _compute_enable_deep_gemm():
|
||||
try:
|
||||
import deep_gemm
|
||||
except ImportError:
|
||||
logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
|
||||
return False
|
||||
|
||||
sm_version = get_device_sm()
|
||||
if sm_version < 90:
|
||||
return False
|
||||
|
||||
return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
|
||||
|
||||
|
||||
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
|
||||
|
||||
DEEPGEMM_V202506 = False
|
||||
|
||||
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
|
||||
@@ -0,0 +1,95 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper import compile_utils
|
||||
from sglang.srt.layers.quantization.deep_gemm_wrapper.configurer import (
|
||||
DEEPGEMM_SCALE_UE8M0,
|
||||
DEEPGEMM_V202506,
|
||||
ENABLE_JIT_DEEPGEMM,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if ENABLE_JIT_DEEPGEMM:
|
||||
import deep_gemm
|
||||
from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
|
||||
from deep_gemm import get_col_major_tma_aligned_tensor
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
|
||||
)
|
||||
from deep_gemm import (
|
||||
m_grouped_gemm_fp8_fp8_bf16_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
|
||||
)
|
||||
|
||||
|
||||
def grouped_gemm_nt_f8f8bf16_masked(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
recipe=None,
|
||||
):
|
||||
num_groups, _, k = lhs[0].shape
|
||||
_, n, _ = rhs[0].shape
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_MASKED
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(
|
||||
expected_m, n, k, num_groups, kernel_type
|
||||
):
|
||||
_grouped_gemm_nt_f8f8bf16_masked_raw(
|
||||
lhs, rhs, out, masked_m, expected_m, recipe=recipe
|
||||
)
|
||||
|
||||
|
||||
def grouped_gemm_nt_f8f8bf16_contig(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
m_indices: torch.Tensor,
|
||||
):
|
||||
m, k = lhs[0].shape
|
||||
num_groups, n, _ = rhs[0].shape
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
_grouped_gemm_nt_f8f8bf16_contig_raw(lhs, rhs, out, m_indices)
|
||||
|
||||
|
||||
def gemm_nt_f8f8bf16(
|
||||
lhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
rhs: Tuple[torch.Tensor, torch.Tensor],
|
||||
out: torch.Tensor,
|
||||
):
|
||||
m, k = lhs[0].shape
|
||||
n, _ = rhs[0].shape
|
||||
num_groups = 1
|
||||
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
|
||||
|
||||
with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
|
||||
_gemm_nt_f8f8bf16_raw(
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
)
|
||||
|
||||
|
||||
def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
|
||||
compile_utils.update_deep_gemm_config(gpu_id, server_args)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def configure_deep_gemm_num_sms(num_sms):
|
||||
if num_sms is None:
|
||||
yield
|
||||
else:
|
||||
original_num_sms = deep_gemm.get_num_sms()
|
||||
deep_gemm.set_num_sms(num_sms)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deep_gemm.set_num_sms(original_num_sms)
|
||||
@@ -23,7 +23,8 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.math_utils import align
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.utils import (
|
||||
direct_register_custom_op,
|
||||
get_device_core_count,
|
||||
@@ -44,10 +45,6 @@ if _is_cuda:
|
||||
sgl_per_token_quant_fp8,
|
||||
)
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -67,7 +64,6 @@ else:
|
||||
fp8_max = torch.finfo(fp8_dtype).max
|
||||
fp8_min = -fp8_max
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt(
|
||||
@@ -77,7 +73,7 @@ if supports_custom_op():
|
||||
Bs: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
) -> None:
|
||||
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||
deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||
|
||||
def deep_gemm_fp8_fp8_bf16_nt_fake(
|
||||
A: torch.Tensor,
|
||||
@@ -797,12 +793,12 @@ def w8a8_block_fp8_matmul_deepgemm(
|
||||
M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
|
||||
|
||||
# Deepgemm only supports output tensor type as bfloat16
|
||||
assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
|
||||
assert C.dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
|
||||
if supports_custom_op():
|
||||
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
|
||||
else:
|
||||
deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||
deep_gemm_wrapper.gemm_nt_f8f8bf16((A, As), (B, Bs), C)
|
||||
|
||||
return C
|
||||
|
||||
@@ -896,7 +892,7 @@ def w8a8_block_fp8_matmul(
|
||||
block_size: List[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
|
||||
if output_dtype == torch.bfloat16 and deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
return w8a8_block_fp8_matmul_deepgemm(
|
||||
A, B, As, Bs, block_size, output_dtype=output_dtype
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
from curses import flash
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import einops
|
||||
import torch
|
||||
|
||||
from sglang.math_utils import align
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
|
||||
@@ -15,7 +15,6 @@ try:
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
fp8_dtype,
|
||||
fp8_max,
|
||||
@@ -138,7 +137,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
|
||||
return cutlass_w8a8_block_fp8_linear_with_fallback
|
||||
elif _use_aiter:
|
||||
return aiter_w8a8_block_fp8_linear
|
||||
elif _ENABLE_JIT_DEEPGEMM:
|
||||
elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
return deepgemm_w8a8_block_fp8_linear_with_fallback
|
||||
else:
|
||||
return triton_w8a8_block_fp8_linear
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt import debug_utils
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
@@ -45,10 +46,9 @@ from sglang.srt.layers.dp_attention import (
|
||||
initialize_dp_attention,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
_ENABLE_JIT_DEEPGEMM,
|
||||
update_deep_gemm_config,
|
||||
from sglang.srt.layers.quantization import (
|
||||
deep_gemm_wrapper,
|
||||
monkey_patch_isinstance_for_vllm_base_layer,
|
||||
)
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||
@@ -205,8 +205,8 @@ class ModelRunner:
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
|
||||
# Update deep gemm configure
|
||||
if _ENABLE_JIT_DEEPGEMM:
|
||||
update_deep_gemm_config(gpu_id, server_args)
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
|
||||
|
||||
# If it is a draft model, tp_group can be different
|
||||
self.initialize(min_per_gpu_memory)
|
||||
|
||||
@@ -54,8 +54,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
per_tensor_quant_mla_fp8,
|
||||
@@ -110,10 +110,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
||||
|
||||
from sglang.srt.layers.quantization.deep_gemm import (
|
||||
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
|
||||
)
|
||||
else:
|
||||
from vllm._custom_ops import awq_dequantize
|
||||
|
||||
@@ -981,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope.new_empty(
|
||||
(self.num_local_heads, aligned_m, self.kv_lora_rank)
|
||||
)
|
||||
deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
(q_nope_val, q_nope_scale),
|
||||
(self.w_kc, self.w_scale_k),
|
||||
q_nope_out,
|
||||
@@ -1851,7 +1847,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
and weight_block_size[1] == 128
|
||||
and model_dtype == torch.bfloat16
|
||||
):
|
||||
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
||||
"SGL_USE_DEEPGEMM_BMM", "false"
|
||||
):
|
||||
block_scale = weight_scale
|
||||
|
||||
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
|
||||
ScatterMode,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
@@ -479,7 +479,9 @@ def _model_forward_tbo(
|
||||
)
|
||||
del inputs
|
||||
|
||||
with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
|
||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||
operations_strategy.deep_gemm_num_sms
|
||||
):
|
||||
outputs_arr = execute_overlapped_operations(
|
||||
inputs_arr=inputs_arr,
|
||||
operations_arr=[operations_strategy.operations] * 2,
|
||||
|
||||
Reference in New Issue
Block a user