Refactor DeepGEMM integration (#7150)

fzyzcjy
2025-06-14 11:41:03 +08:00
committed by GitHub
parent 8b8f2e7463
commit b4c41f7276
12 changed files with 207 additions and 147 deletions
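The change replaces the per-file `try: import deep_gemm` block and the local `use_deep_gemm` flag with a single facade module, `sglang.srt.layers.quantization.deep_gemm_wrapper`, which owns the optional import, the JIT-enable flag, and the API-version check; call sites now go through `deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig` / `grouped_gemm_nt_f8f8bf16_masked` and test `deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM`. A minimal sketch of such a facade follows; only the exported names are taken from the diff, while the environment-variable gate and the version probe are assumptions for illustration, not the actual module contents.

# deep_gemm_wrapper.py -- hypothetical sketch of the compatibility facade.
# Exported names (ENABLE_JIT_DEEPGEMM, DEEPGEMM_V202506, grouped_gemm_* and
# get_col_major_tma_aligned_tensor) come from the diff; the env-var name and
# the version probe are assumptions.
import os

try:
    import deep_gemm

    _HAS_DEEP_GEMM = True
except ImportError:
    deep_gemm = None
    _HAS_DEEP_GEMM = False

# Assumed environment gate: one flag that every caller checks instead of a
# local use_deep_gemm boolean.
ENABLE_JIT_DEEPGEMM = _HAS_DEEP_GEMM and os.environ.get(
    "SGL_ENABLE_JIT_DEEPGEMM", "true"
).lower() in ("1", "true", "yes")

# Assumed probe: the 2025-06 DeepGEMM release renamed its grouped-GEMM entry
# points, so attribute presence is used here as the version signal.
DEEPGEMM_V202506 = _HAS_DEEP_GEMM and hasattr(
    deep_gemm, "m_grouped_fp8_gemm_nt_contiguous"
)


def grouped_gemm_nt_f8f8bf16_contig(lhs, rhs, out, m_indices):
    # Thin forwarder so call sites never import deep_gemm directly.
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)


def grouped_gemm_nt_f8f8bf16_masked(lhs, rhs, out, masked_m, expected_m, recipe=None):
    # The wrapper is the one place where a caller-supplied recipe would be
    # translated for the newer API; that translation is omitted in this sketch
    # and the legacy signature is used as-is.
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m)


def get_col_major_tma_aligned_tensor(t):
    return deep_gemm.get_col_major_tma_aligned_tensor(t)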


@@ -1,30 +1,11 @@
import logging
from typing import Callable, List, Optional, Tuple
import einops
import torch
from sgl_kernel import silu_and_mul
from torch.nn import Module
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
try:
    from deep_gemm import (
        get_col_major_tma_aligned_tensor,
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
    )
    from sgl_kernel import silu_and_mul
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
    )
    use_deep_gemm = True
except ImportError:
    use_deep_gemm = False
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import (
    scaled_fp8_quant,
    sglang_per_token_group_quant_fp8,
    sglang_per_token_quant_fp8,
)
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
from sglang.srt.utils import (
    DeepEPMode,
    dispose_tensor,
    get_bool_env_var,
    is_hip,
    set_weight_attrs,
)
_is_hip = is_hip()
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
        )
        self.deepep_mode = deepep_mode
        if self.deepep_mode.enable_low_latency():
            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
            assert (
                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
    ):
        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
        if resolved_deepep_mode == DeepEPMode.normal:
            if _ENABLE_JIT_DEEPGEMM:
            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                return self.forward_deepgemm_contiguous(
                    hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
            dtype=torch.bfloat16,
        )
        input_tensor[1] = tma_align_input_scale(input_tensor[1])
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
        )
        del input_tensor
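For context on what the contiguous grouped GEMM computes at this call site: all routed tokens are packed into one contiguous activation tensor and m_indices[i] names the expert whose weight multiplies row i. The following pure-PyTorch reference is my own sketch (bf16, no FP8 data or scales, DeepGEMM's alignment and padding conventions ignored), not code from the repository.

import torch

def contiguous_grouped_gemm_reference(x, w, m_indices):
    # x: [total_m, k] packed activations, w: [num_experts, n, k] ("nt": the
    # second operand is transposed), m_indices: [total_m] expert id per row
    # (entries matching no expert, e.g. negative padding, stay at zero).
    out = torch.zeros(x.shape[0], w.shape[1], dtype=torch.bfloat16, device=x.device)
    for e in range(w.shape[0]):
        rows = m_indices == e
        if rows.any():
            out[rows] = x[rows] @ w[e].t()
    return out

# Example: 16 tokens routed across 4 experts, hidden size 32, output width 64.
x = torch.randn(16, 32, dtype=torch.bfloat16)
w = torch.randn(4, 64, 32, dtype=torch.bfloat16)
m_indices = torch.repeat_interleave(torch.arange(4), 4)
y = contiguous_grouped_gemm_reference(x, w, m_indices)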
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
        )
        del down_input
        down_input_scale = tma_align_input_scale(down_input_scale)
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
            (down_input_fp8, down_input_scale),
            self.w2_weight_fp8,
            down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
        )
        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            hidden_states_fp8,
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
        )
        dispose_tensor(hidden_states_fp8[0])
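The masked variant used here operates on a per-expert batch layout instead: each expert group g holds up to m rows, of which only the first masked_m[g] are valid, and expected_m appears to serve only as a sizing hint for the kernel. A minimal pure-PyTorch reference of that computation (my own sketch, bf16 in place of FP8 data plus scales) is:

import torch

def masked_grouped_gemm_reference(x, w, masked_m):
    # x: [num_groups, m, k] activations, rows beyond masked_m[g] are padding
    # w: [num_groups, n, k] per-expert weights ("nt" layout)
    # returns: [num_groups, m, n]; padded rows stay zero
    num_groups, m, _ = x.shape
    out = torch.zeros(num_groups, m, w.shape[1], dtype=torch.bfloat16, device=x.device)
    for g in range(num_groups):
        rows = int(masked_m[g])
        out[g, :rows] = x[g, :rows] @ w[g].t()
    return out

# Example: 4 experts, up to 8 tokens each, hidden size 16, output width 32.
x = torch.randn(4, 8, 16, dtype=torch.bfloat16)
w = torch.randn(4, 32, 16, dtype=torch.bfloat16)
masked_m = torch.tensor([8, 3, 0, 5])
y = masked_grouped_gemm_reference(x, w, masked_m)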
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            get_col_major_tma_aligned_tensor(down_input_scale),
            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
        )
        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
        )
        return down_output
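On the recipe=(1, 128, 128) argument added above: it is only passed when the wrapper reports the 2025-06 DeepGEMM API, and it presumably names the FP8 scaling granularity (1x128 groups on the activation side, 128x128 blocks on the weight side) that this file already produces via sglang_per_token_group_quant_fp8 and the block-scaled weights. A simplified sketch of that activation-side 1x128 quantization (an illustration, not the sglang kernel) is:

import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def per_token_group_quant_fp8(x, group_size=128):
    # x: [m, k] with k divisible by group_size; returns FP8 data plus one
    # scale per 1 x group_size slice, i.e. the "1x128" side of the recipe.
    m, k = x.shape
    xg = x.float().view(m, k // group_size, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_E4M3_MAX
    q = (xg / scale).to(torch.float8_e4m3fn)
    return q.view(m, k), scale.squeeze(-1)

# Example: quantize a [4, 256] activation into FP8 with two scales per row.
x = torch.randn(4, 256)
q, s = per_token_group_quant_fp8(x)  # q: [4, 256] fp8, s: [4, 2] fp32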