Refactor DeepGEMM integration (#7150)
@@ -1,30 +1,11 @@
 import logging
 from typing import Callable, List, Optional, Tuple
 
+import einops
 import torch
+from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-try:
-    from deep_gemm import (
-        get_col_major_tma_aligned_tensor,
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-    )
-    from sgl_kernel import silu_and_mul
-
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
-
-    use_deep_gemm = True
-except ImportError:
-    use_deep_gemm = False
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
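
Note: the hunk above removes this file's local try/except ImportError guard and its use_deep_gemm flag; the hunks below route every DeepGEMM call through the shared deep_gemm_wrapper module instead. A minimal sketch of the guard pattern such a wrapper can centralize, assuming only the ENABLE_JIT_DEEPGEMM name that appears later in this diff (hypothetical code, not taken from the sglang sources):

# Hypothetical wrapper-side guard; the real sglang module may differ.
try:
    import deep_gemm  # optional dependency, absent on unsupported platforms

    _HAS_DEEP_GEMM = True
except ImportError:
    _HAS_DEEP_GEMM = False

# Call sites test one shared flag instead of each defining use_deep_gemm.
# A real implementation may also consult environment variables (note the
# get_bool_env_var import added further down in this diff).
ENABLE_JIT_DEEPGEMM = _HAS_DEEP_GEMM
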
@@ -45,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -52,10 +34,20 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
+    sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    DeepEPMode,
+    dispose_tensor,
+    get_bool_env_var,
+    is_hip,
+    set_weight_attrs,
+)
 
 _is_hip = is_hip()
 
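
The call sites changed below no longer import DeepGEMM kernels directly; they call through deep_gemm_wrapper. A rough sketch of the forwarding shape those call sites imply, with argument meanings inferred from the old direct calls in this diff; the wrapper's real implementation is not part of this commit view:

from typing import Tuple

import torch


def grouped_gemm_nt_f8f8bf16_contig(
    lhs: Tuple[torch.Tensor, torch.Tensor],  # (fp8 activations, per-group scales)
    rhs: Tuple[torch.Tensor, torch.Tensor],  # (fp8 expert weights, weight scales)
    out: torch.Tensor,  # bf16 output buffer, written in place
    m_indices: torch.Tensor,  # expert id for each contiguous block of rows
) -> None:
    # Hypothetical forwarder mirroring the old direct call
    # m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices).
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_contiguous

    m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs, rhs, out, m_indices)
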
@@ -680,7 +672,6 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
@@ -920,7 +911,9 @@ class DeepEPMoE(EPMoE):
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
-            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+            assert (
+                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
         self.w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -948,7 +941,7 @@ class DeepEPMoE(EPMoE):
     ):
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
-            if _ENABLE_JIT_DEEPGEMM:
+            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
                     hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
                 )
@@ -1145,7 +1138,7 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
@@ -1169,7 +1162,7 @@ class DeepEPMoE(EPMoE):
         )
         del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
-        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
             down_output,
@@ -1202,8 +1195,13 @@ class DeepEPMoE(EPMoE):
         gateup_output = torch.empty(
             (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            hidden_states_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
         dispose_tensor(hidden_states_fp8[0])
 
@@ -1240,13 +1238,18 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            get_col_major_tma_aligned_tensor(down_input_scale),
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
         )
-        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_V202506 else None,
         )
 
         return down_output
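
For the masked low-latency path in the two hunks above, a sketch of the corresponding wrapper-side forwarder, again inferred from the call sites rather than taken from the wrapper itself; how the recipe keyword is handled on older DeepGEMM builds is an assumption:

from typing import Optional, Tuple

import torch


def grouped_gemm_nt_f8f8bf16_masked(
    lhs: Tuple[torch.Tensor, torch.Tensor],  # (fp8 activations, TMA-aligned scales)
    rhs: Tuple[torch.Tensor, torch.Tensor],  # (fp8 expert weights, weight scales)
    out: torch.Tensor,  # bf16 output of shape (num_groups, m, n)
    masked_m: torch.Tensor,  # number of valid rows per expert group
    expected_m: int,  # hint for the typical number of rows per group
    recipe: Optional[Tuple[int, int, int]] = None,
) -> None:
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked

    if recipe is None:
        # Older DeepGEMM builds have no recipe keyword.
        m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, out, masked_m, expected_m)
    else:
        # Assumed: the 2025-06 DeepGEMM line accepts an explicit quantization recipe.
        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
            lhs, rhs, out, masked_m, expected_m, recipe=recipe
        )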