[1/N]Support DeepSeek-R1 w4a8 normal deepep (#8247)
Co-authored-by: Hank Han <hanhan7630@outlook.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Cutlass W4A8 MoE kernel."""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -11,6 +12,9 @@ from sgl_kernel import (
|
||||
)
|
||||
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_permute_triton_kernel,
|
||||
deepep_post_reorder_triton_kernel,
|
||||
deepep_run_moe_deep_preprocess,
|
||||
post_reorder_triton_kernel_for_cutlass_moe,
|
||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||
run_moe_ep_preproess,
|
||||
@@ -201,3 +205,195 @@ def cutlass_w4a8_moe(
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def cutlass_w4a8_moe_deepep_normal(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    for tokens received through DeepEP normal dispatch, using two sets of
    quantized weights, w1_q and w2_q, and top-k gating mechanism. The matrix
    multiplications are implemented with CUTLASS grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights
        (stored as int8).
        Shape: [num_experts, N * 2,  K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights
        (stored as int8).
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
        Must have the same shape as topk_ids_.
    - topk_ids_ (torch.Tensor): The expert id of each token->expert mapping.
        Shape: [M, topk]. Entries equal to -1 are remapped to num_experts
        before the grouped-gemm maps are built.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data; all but its last entry are then passed
        to both grouped gemms (presumably filled in place - TODO confirm).
    - problem_sizes1 (torch.Tensor): Buffer handed to
        get_cutlass_w4a8_moe_mm_data and then to the first grouped gemm.
    - problem_sizes2 (torch.Tensor): Same as problem_sizes1, for the second
        grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Required in practice even though the signature defaults it to None.
        Shape: scalar or [1, K]
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to quantize the
        intermediate result between the gemms. Required in practice even
        though the signature defaults it to None.
        Shape: scalar or [1, N]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
        Shape: [src2dst.shape[0] // topk, K]

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    # Both scales are dereferenced unconditionally below; fail early with a
    # clear message instead of an AttributeError half-way through the kernels.
    if a1_scale is None or a2_scale is None:
        raise ValueError(
            "a1_scale and a2_scale must be provided: both are used to "
            "quantize activations to fp8"
        )

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    # Convert the activation scales once; each is used by both the
    # quantization kernel and the matching grouped gemm.
    a1_scale_f32 = a1_scale.float()
    a2_scale_f32 = a2_scale.float()

    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
        topk_ids_, num_experts
    )
    num_total_tokens = reorder_topk_ids.numel()

    # Gather the input rows into the layout described by src2dst.
    gateup_input_pre_reorder = torch.empty(
        (num_total_tokens, a.shape[1]),
        device=device,
        dtype=a.dtype,
    )
    deepep_permute_triton_kernel[(a.shape[0],)](
        a,
        gateup_input_pre_reorder,
        src2dst,
        topk_ids_.to(torch.int64),
        None,
        topk,
        a.shape[1],
        BLOCK_SIZE=512,
    )

    # Quantize the permuted activations to fp8 with the provided scale.
    gateup_input = torch.empty(
        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(
        gateup_input_pre_reorder, gateup_input, a1_scale_f32, True
    )
    # Drop the unquantized copy before the gemm outputs are allocated.
    del gateup_input_pre_reorder

    # Entries marked -1 are redirected to the id `num_experts`, i.e. one past
    # the last valid expert index.
    local_topk_ids = (
        torch.where(topk_ids_ == -1, num_experts, topk_ids_).to(torch.int32)
    ).contiguous()

    a_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    c_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
    # NOTE(review): c2 is zero-initialised while c1 is not; presumably rows the
    # second gemm never writes must read as zero in the final reorder - confirm.
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

    # First grouped gemm (w1_q).
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,  # NOTE(review): meaning of this constant is not visible here
        topk,
    )

    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale_f32, True)

    # Second grouped gemm (w2_q).
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale_f32,
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,
        topk,
    )

    # Reorder the per-expert results back to one row per token, using
    # src2dst, topk_ids_ and topk_weights.
    num_tokens = src2dst.shape[0] // topk
    output = torch.empty(
        (num_tokens, c2.shape[1]),
        device=c2.device,
        dtype=torch.bfloat16,
    )
    deepep_post_reorder_triton_kernel[(num_tokens,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        topk,
        c2.shape[1],
        BLOCK_SIZE=512,
    )

    return output
|
||||
|
||||
@@ -29,6 +29,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
CUTEDSL_MOE_NVFP4_DISPATCH,
|
||||
ModelOptNvFp4FusedMoEMethod,
|
||||
)
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||
@@ -96,6 +97,11 @@ class DeepEPMoE(FusedMoE):
|
||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||
self.use_fp8_w8a8 = True
|
||||
self.fp8_dtype = torch.float8_e4m3fn
|
||||
self.use_w4afp8 = False
|
||||
elif isinstance(quant_config, W4AFp8Config):
|
||||
self.use_w4afp8 = True
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
else:
|
||||
self.use_fp8_w8a8 = False
|
||||
self.use_block_quant = False
|
||||
@@ -142,7 +148,7 @@ class DeepEPMoE(FusedMoE):
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
if self.use_block_quant or self.use_w4afp8
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
@@ -150,7 +156,7 @@ class DeepEPMoE(FusedMoE):
|
||||
self.w2_weight,
|
||||
(
|
||||
self.w2_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
if self.use_block_quant or self.use_w4afp8
|
||||
else self.w2_weight_scale
|
||||
),
|
||||
)
|
||||
@@ -210,6 +216,8 @@ class DeepEPMoE(FusedMoE):
|
||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||
return self.forward_npu(dispatch_output)
|
||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||
if self.use_w4afp8:
|
||||
return self.forward_cutlass_w4afp8(dispatch_output)
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
@@ -438,6 +446,17 @@ class DeepEPMoE(FusedMoE):
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_cutlass_w4afp8(self, dispatch_output: DeepEPNormalOutput):
    """Run the W4AFp8 CUTLASS MoE path on a DeepEP normal-mode dispatch output.

    Checks that the configured activation is "silu" and that the layer's
    quant method is a ``W4AFp8MoEMethod``, then hands the work to that
    method's ``apply_deepep_normal`` and returns its result.
    """
    activation = self.moe_runner_config.activation
    assert activation == "silu"

    method = self.quant_method
    assert isinstance(method, W4AFp8MoEMethod)

    return method.apply_deepep_normal(layer=self, dispatch_output=dispatch_output)
|
||||
|
||||
def forward_deepgemm_masked(
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
|
||||
@@ -14,7 +14,12 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
|
||||
from sglang.srt.layers.moe.utils import (
|
||||
DeepEPMode,
|
||||
get_deepep_config,
|
||||
get_moe_runner_backend,
|
||||
is_tbo_enabled,
|
||||
)
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
@@ -340,7 +345,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
if (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
and not get_moe_runner_backend().is_cutlass()
|
||||
):
|
||||
# TODO: hard-coded 128 block quant, use fp8 communication
|
||||
hidden_states = sglang_per_token_group_quant_fp8(
|
||||
hidden_states,
|
||||
@@ -386,7 +394,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
)
|
||||
|
||||
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
|
||||
# However, doing this would incur an unknown synchronization error, but keeping
|
||||
# `handle` as a member variable works.
|
||||
@@ -412,7 +419,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
|
||||
config=DeepEPConfig.get_instance().normal_dispatch_config,
|
||||
)
|
||||
|
||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||
num_recv_tokens_per_expert,
|
||||
num_tokens_per_rank=num_tokens_per_rank,
|
||||
|
||||
@@ -55,6 +55,7 @@ class MoeRunnerBackend(Enum):
|
||||
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
||||
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
||||
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
|
||||
CUTLASS = "cutlass"
|
||||
|
||||
def is_auto(self):
|
||||
return self == MoeRunnerBackend.AUTO
|
||||
@@ -80,6 +81,9 @@ class MoeRunnerBackend(Enum):
|
||||
def is_flashinfer_mxfp4(self):
|
||||
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
||||
|
||||
def is_cutlass(self):
    """Return True when this backend is ``MoeRunnerBackend.CUTLASS``."""
    return self is MoeRunnerBackend.CUTLASS
|
||||
|
||||
|
||||
class DeepEPMode(Enum):
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -21,8 +21,10 @@ from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
CombineInput,
|
||||
DeepEPNormalOutput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
@@ -326,3 +328,47 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||
output *= self.moe_runner_config.routed_scaling_factor
|
||||
return StandardCombineInput(hidden_states=output)
|
||||
|
||||
def apply_deepep_normal(
    self,
    layer: DeepEPMoE,
    dispatch_output: DeepEPNormalOutput,
) -> torch.Tensor:
    """Apply the W4A8 CUTLASS MoE kernel to a DeepEP normal dispatch output.

    Reads ``hidden_states``, ``topk_idx`` and ``topk_weights`` from
    ``dispatch_output``. When ``hidden_states`` is a tuple only its first
    element is used. If there are no tokens the hidden states are returned
    untouched; otherwise the result of ``cutlass_w4a8_moe_deepep_normal`` is
    returned, fed with the layer's weights and scales and this method's
    stride, offset and problem-size buffers.
    """
    # Imported inside the function, as in the original code.
    from sglang.srt.layers.moe.cutlass_w4a8_moe import (
        cutlass_w4a8_moe_deepep_normal,
    )

    hidden_states = dispatch_output.hidden_states
    topk_idx = dispatch_output.topk_idx
    topk_weights = dispatch_output.topk_weights

    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    # Nothing to compute for an empty batch.
    if hidden_states.shape[0] <= 0:
        return hidden_states

    return cutlass_w4a8_moe_deepep_normal(
        a=hidden_states,
        w1_q=layer.w13_weight,
        w2_q=layer.w2_weight,
        w1_scale=layer.w13_weight_scale_inv,
        w2_scale=layer.w2_weight_scale_inv,
        topk_weights=topk_weights,
        topk_ids_=topk_idx,
        a_strides1=self.a_strides1,
        b_strides1=self.b_strides1,
        c_strides1=self.c_strides1,
        a_strides2=self.a_strides2,
        b_strides2=self.b_strides2,
        c_strides2=self.c_strides2,
        s_strides13=self.s_strides13,
        s_strides2=self.s_strides2,
        expert_offsets=self.expert_offsets,
        problem_sizes1=self.problem_sizes1,
        problem_sizes2=self.problem_sizes2,
        a1_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
    )
|
||||
|
||||
@@ -137,6 +137,7 @@ MOE_RUNNER_BACKEND_CHOICES = [
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_mxfp4",
|
||||
"flashinfer_cutedsl",
|
||||
"cutlass",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user