[1/N]Support DeepSeek-R1 w4a8 normal deepep (#8247)
Co-authored-by: Hank Han <hanhan7630@outlook.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
"""Cutlass W4A8 MoE kernel."""
|
"""Cutlass W4A8 MoE kernel."""
|
||||||
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -11,6 +12,9 @@ from sgl_kernel import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
|
deepep_permute_triton_kernel,
|
||||||
|
deepep_post_reorder_triton_kernel,
|
||||||
|
deepep_run_moe_deep_preprocess,
|
||||||
post_reorder_triton_kernel_for_cutlass_moe,
|
post_reorder_triton_kernel_for_cutlass_moe,
|
||||||
pre_reorder_triton_kernel_for_cutlass_moe,
|
pre_reorder_triton_kernel_for_cutlass_moe,
|
||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
@@ -201,3 +205,195 @@ def cutlass_w4a8_moe(
|
|||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w4a8_moe_deepep_normal(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids_: torch.Tensor,
    a_strides1: torch.Tensor,
    b_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides13: torch.Tensor,
    s_strides2: torch.Tensor,
    expert_offsets: torch.Tensor,
    problem_sizes1: torch.Tensor,
    problem_sizes2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
    for the DeepEP "normal" dispatch path, using two sets of quantized
    weights, w1_q and w2_q, and top-k gating mechanism. The matrix
    multiplications are implemented with CUTLASS grouped gemm.

    Pipeline: permute tokens by expert -> fp8-quantize -> grouped gemm (w1)
    -> silu_and_mul -> fp8-quantize -> grouped gemm (w2) -> weighted
    post-reorder back to token order.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
        Shape: [num_experts, N * 2,  K // 2]
        (the weights are passed transposed and int4-packed)
    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
        Shape: [num_experts, K, N // 2]
        (the weights are passed transposed and int4-packed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts, K // 512, N * 8]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts, N // 512, K * 4]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
        Must have the same shape as topk_ids_.
    - topk_ids_ (torch.Tensor): The expert id of each token->expert mapping.
        Shape: [M, topk]. Entries equal to -1 are remapped to num_experts
        before the grouped-gemm metadata is computed.
    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
    - expert_offsets, problem_sizes1, problem_sizes2 (torch.Tensor): Buffers
        handed to get_cutlass_w4a8_moe_mm_data and then to the grouped gemms
        (presumably filled in by that helper -- confirm against sgl_kernel).
    - a1_scale (Optional[torch.Tensor]): The fp32 scale to quantize a.
        Shape: scalar or [1, K]. Despite the Optional annotation it is
        required by this implementation.
    - a2_scale (Optional[torch.Tensor]): The fp32 scale to quantize the
        intermediate result between the gemms.
        Shape: scalar or [1, N]. Despite the Optional annotation it is
        required by this implementation.

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
        Shape: [src2dst.shape[0] // topk, K]

    Raises:
    - ValueError: If a1_scale or a2_scale is None.
    """
    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.int8
    assert w2_q.dtype == torch.int8
    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"

    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"

    # Both activation scales are dereferenced unconditionally below; fail
    # early with a clear message instead of an AttributeError mid-pipeline.
    if a1_scale is None or a2_scale is None:
        raise ValueError(
            "cutlass_w4a8_moe_deepep_normal requires both a1_scale and a2_scale"
        )
    a1_scale_f32 = a1_scale.float()
    a2_scale_f32 = a2_scale.float()

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(2) * 2  # w1_q is transposed and packed
    n = w2_q.size(2) * 2  # w2_q is transposed and packed
    topk = topk_ids_.size(1)
    device = a.device

    # Group the (token, expert) pairs by expert and copy the activations
    # into that expert-major order.
    reorder_topk_ids, src2dst, _ = deepep_run_moe_deep_preprocess(
        topk_ids_, num_experts
    )
    num_total_tokens = reorder_topk_ids.numel()
    gateup_input_pre_reorder = torch.empty(
        (int(num_total_tokens), a.shape[1]),
        device=device,
        dtype=a.dtype,
    )
    deepep_permute_triton_kernel[(m,)](
        a,
        gateup_input_pre_reorder,
        src2dst,
        topk_ids_.to(torch.int64),
        None,
        topk,
        a.shape[1],
        BLOCK_SIZE=512,
    )

    # Quantize the permuted activations to fp8 for the first grouped gemm.
    gateup_input = torch.empty(
        gateup_input_pre_reorder.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(
        gateup_input_pre_reorder, gateup_input, a1_scale_f32, True
    )
    # Release the high-precision copy before the larger gemm buffers are
    # allocated.
    del gateup_input_pre_reorder

    # -1 marks an unused slot; map it to the out-of-range id `num_experts`.
    local_topk_ids = (
        torch.where(topk_ids_ == -1, num_experts, topk_ids_).to(torch.int32)
    ).contiguous()

    a_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    c_map = torch.empty(local_topk_ids.numel(), dtype=torch.int32, device=device)
    get_cutlass_w4a8_moe_mm_data(
        local_topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        num_experts,
        n,
        k,
    )

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
    # c2 is zero-initialised (unlike c1); presumably so rows the second gemm
    # never writes contribute nothing in the post-reorder -- confirm.
    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)

    # First grouped gemm: gate/up projection.
    cutlass_w4a8_moe_mm(
        c1,
        gateup_input,
        w1_q,
        a1_scale_f32,
        w1_scale,
        expert_offsets[:-1],
        problem_sizes1,
        a_strides1,
        b_strides1,
        c_strides1,
        s_strides13,
        128,  # NOTE(review): presumably the weight-scale group size -- confirm
        topk,
    )

    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
    silu_and_mul(c1, intermediate)

    intermediate_q = torch.empty(
        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
    )
    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale_f32, True)

    # Second grouped gemm: down projection.
    cutlass_w4a8_moe_mm(
        c2,
        intermediate_q,
        w2_q,
        a2_scale_f32,
        w2_scale,
        expert_offsets[:-1],
        problem_sizes2,
        a_strides2,
        b_strides2,
        c_strides2,
        s_strides2,
        128,  # NOTE(review): presumably the weight-scale group size -- confirm
        topk,
    )

    # Scatter the expert outputs back to token order, applying topk_weights.
    num_tokens = src2dst.shape[0] // topk
    output = torch.empty(
        (num_tokens, c2.shape[1]),
        device=c2.device,
        dtype=torch.bfloat16,
    )
    deepep_post_reorder_triton_kernel[(num_tokens,)](
        c2,
        output,
        src2dst,
        topk_ids_,
        topk_weights,
        topk,
        c2.shape[1],
        BLOCK_SIZE=512,
    )

    return output
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
|||||||
CUTEDSL_MOE_NVFP4_DISPATCH,
|
CUTEDSL_MOE_NVFP4_DISPATCH,
|
||||||
ModelOptNvFp4FusedMoEMethod,
|
ModelOptNvFp4FusedMoEMethod,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||||
@@ -96,6 +97,11 @@ class DeepEPMoE(FusedMoE):
|
|||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
self.use_fp8_w8a8 = True
|
self.use_fp8_w8a8 = True
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
|
self.use_w4afp8 = False
|
||||||
|
elif isinstance(quant_config, W4AFp8Config):
|
||||||
|
self.use_w4afp8 = True
|
||||||
|
self.use_fp8_w8a8 = False
|
||||||
|
self.use_block_quant = False
|
||||||
else:
|
else:
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
@@ -142,7 +148,7 @@ class DeepEPMoE(FusedMoE):
|
|||||||
self.w13_weight,
|
self.w13_weight,
|
||||||
(
|
(
|
||||||
self.w13_weight_scale_inv
|
self.w13_weight_scale_inv
|
||||||
if self.use_block_quant
|
if self.use_block_quant or self.use_w4afp8
|
||||||
else self.w13_weight_scale
|
else self.w13_weight_scale
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -150,7 +156,7 @@ class DeepEPMoE(FusedMoE):
|
|||||||
self.w2_weight,
|
self.w2_weight,
|
||||||
(
|
(
|
||||||
self.w2_weight_scale_inv
|
self.w2_weight_scale_inv
|
||||||
if self.use_block_quant
|
if self.use_block_quant or self.use_w4afp8
|
||||||
else self.w2_weight_scale
|
else self.w2_weight_scale
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -210,6 +216,8 @@ class DeepEPMoE(FusedMoE):
|
|||||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||||
return self.forward_npu(dispatch_output)
|
return self.forward_npu(dispatch_output)
|
||||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
|
if self.use_w4afp8:
|
||||||
|
return self.forward_cutlass_w4afp8(dispatch_output)
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
@@ -438,6 +446,17 @@ class DeepEPMoE(FusedMoE):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def forward_cutlass_w4afp8(
    self,
    dispatch_output: DeepEPNormalOutput,
):
    """Run the W4AFp8 MoE path for a DeepEP normal-mode dispatch output.

    Delegates to the quantization method's ``apply_deepep_normal`` with this
    layer as ``layer``. Only the ``"silu"`` activation is accepted, and the
    layer's quant method must be a ``W4AFp8MoEMethod``.
    """
    assert self.moe_runner_config.activation == "silu"
    w4afp8_method = self.quant_method
    assert isinstance(w4afp8_method, W4AFp8MoEMethod)
    return w4afp8_method.apply_deepep_normal(
        layer=self, dispatch_output=dispatch_output
    )
|
||||||
|
|
||||||
def forward_deepgemm_masked(
|
def forward_deepgemm_masked(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPLLOutput,
|
dispatch_output: DeepEPLLOutput,
|
||||||
|
|||||||
@@ -14,7 +14,12 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
|||||||
DispatchOutput,
|
DispatchOutput,
|
||||||
DispatchOutputFormat,
|
DispatchOutputFormat,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
|
from sglang.srt.layers.moe.utils import (
|
||||||
|
DeepEPMode,
|
||||||
|
get_deepep_config,
|
||||||
|
get_moe_runner_backend,
|
||||||
|
is_tbo_enabled,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
@@ -340,7 +345,10 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if (
|
||||||
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
|
and not get_moe_runner_backend().is_cutlass()
|
||||||
|
):
|
||||||
# TODO hard code 128 block quant,use fp8 communication
|
# TODO hard code 128 block quant,use fp8 communication
|
||||||
hidden_states = sglang_per_token_group_quant_fp8(
|
hidden_states = sglang_per_token_group_quant_fp8(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -386,7 +394,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=previous_event is not None,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
|
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
|
||||||
# However, doing this would incur an unknown synchronization error, but keeping
|
# However, doing this would incur an unknown synchronization error, but keeping
|
||||||
# `handle` as a member variable works.
|
# `handle` as a member variable works.
|
||||||
@@ -412,7 +419,6 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
|
expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
|
||||||
config=DeepEPConfig.get_instance().normal_dispatch_config,
|
config=DeepEPConfig.get_instance().normal_dispatch_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||||
num_recv_tokens_per_expert,
|
num_recv_tokens_per_expert,
|
||||||
num_tokens_per_rank=num_tokens_per_rank,
|
num_tokens_per_rank=num_tokens_per_rank,
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class MoeRunnerBackend(Enum):
|
|||||||
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
FLASHINFER_CUTLASS = "flashinfer_cutlass"
|
||||||
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
|
||||||
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
|
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
|
||||||
|
CUTLASS = "cutlass"
|
||||||
|
|
||||||
def is_auto(self):
|
def is_auto(self):
|
||||||
return self == MoeRunnerBackend.AUTO
|
return self == MoeRunnerBackend.AUTO
|
||||||
@@ -80,6 +81,9 @@ class MoeRunnerBackend(Enum):
|
|||||||
def is_flashinfer_mxfp4(self):
|
def is_flashinfer_mxfp4(self):
|
||||||
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
return self == MoeRunnerBackend.FLASHINFER_MXFP4
|
||||||
|
|
||||||
|
def is_cutlass(self):
    """Return True when this backend is the ``CUTLASS`` member."""
    # Enum members are singletons, so identity and equality coincide here.
    return self is MoeRunnerBackend.CUTLASS
|
||||||
|
|
||||||
|
|
||||||
class DeepEPMode(Enum):
|
class DeepEPMode(Enum):
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
@@ -21,8 +21,10 @@ from sglang.srt.utils import is_npu, set_weight_attrs
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
CombineInput,
|
CombineInput,
|
||||||
|
DeepEPNormalOutput,
|
||||||
StandardDispatchOutput,
|
StandardDispatchOutput,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -326,3 +328,47 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
if self.moe_runner_config.routed_scaling_factor is not None:
|
if self.moe_runner_config.routed_scaling_factor is not None:
|
||||||
output *= self.moe_runner_config.routed_scaling_factor
|
output *= self.moe_runner_config.routed_scaling_factor
|
||||||
return StandardCombineInput(hidden_states=output)
|
return StandardCombineInput(hidden_states=output)
|
||||||
|
|
||||||
|
def apply_deepep_normal(
    self,
    layer: DeepEPMoE,
    dispatch_output: DeepEPNormalOutput,
) -> torch.Tensor:
    """Apply the CUTLASS W4A8 MoE kernel to a DeepEP normal dispatch output.

    Args:
        layer: The MoE layer providing the packed weights, the weight scales
            (``*_weight_scale_inv``) and the activation input scales.
        dispatch_output: Dispatch result carrying ``hidden_states``,
            ``topk_idx`` and ``topk_weights``. If ``hidden_states`` is a
            tuple, only its first element is used.

    Returns:
        The result of ``cutlass_w4a8_moe_deepep_normal``, or the (empty)
        hidden states unchanged when no tokens were received.
    """
    # Imported lazily, as in the rest of this method's call path.
    from sglang.srt.layers.moe.cutlass_w4a8_moe import (
        cutlass_w4a8_moe_deepep_normal,
    )

    tokens = dispatch_output.hidden_states
    expert_ids = dispatch_output.topk_idx
    expert_weights = dispatch_output.topk_weights
    if isinstance(tokens, tuple):
        tokens = tokens[0]

    # Nothing was routed to this rank: skip the kernels entirely.
    if tokens.shape[0] <= 0:
        return tokens

    stride_args = (
        self.a_strides1,
        self.b_strides1,
        self.c_strides1,
        self.a_strides2,
        self.b_strides2,
        self.c_strides2,
        self.s_strides13,
        self.s_strides2,
    )
    return cutlass_w4a8_moe_deepep_normal(
        tokens,
        layer.w13_weight,
        layer.w2_weight,
        layer.w13_weight_scale_inv,
        layer.w2_weight_scale_inv,
        expert_weights,
        expert_ids,
        *stride_args,
        self.expert_offsets,
        self.problem_sizes1,
        self.problem_sizes2,
        layer.w13_input_scale,
        layer.w2_input_scale,
    )
|
||||||
|
|||||||
@@ -137,6 +137,7 @@ MOE_RUNNER_BACKEND_CHOICES = [
|
|||||||
"flashinfer_cutlass",
|
"flashinfer_cutlass",
|
||||||
"flashinfer_mxfp4",
|
"flashinfer_mxfp4",
|
||||||
"flashinfer_cutedsl",
|
"flashinfer_cutedsl",
|
||||||
|
"cutlass",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -118,5 +118,60 @@ class TestDeepseekV3W4Afp8Mtp(CustomTestCase):
|
|||||||
self.assertGreater(avg_spec_accept_length, 2.9)
|
self.assertGreater(avg_spec_accept_length, 2.9)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeepseekV3W4Afp8DeepepNormal(CustomTestCase):
    """GSM8K accuracy check for the W4AFp8 model served with DeepEP normal
    mode and the ``cutlass`` MoE runner backend."""

    @classmethod
    def setUpClass(cls):
        """Launch one server for the whole test class."""
        cls.model = try_cached_model(DEFAULT_DEEPSEEK_W4AFP8_MODEL_FOR_TEST)
        cls.base_url = DEFAULT_URL_FOR_TEST

        launch_flags = ["--tp", "8", "--trust-remote-code", "--ep-size", "8"]
        launch_flags.extend(["--cuda-graph-bs", "256", "--disable-radix-cache"])
        launch_flags.extend(["--moe-a2a-backend", "deepep"])
        launch_flags.extend(["--deepep-mode", "normal"])
        launch_flags.extend(["--dp", "8", "--enable-dp-attention"])
        launch_flags.extend(["--moe-runner-backend", "cutlass"])
        if not is_in_amd_ci():
            launch_flags.extend(["--mem-frac", "0.7"])

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server and every process it spawned."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Few-shot GSM8K accuracy must exceed 0.92."""
        server_port = int(self.base_url.split(":")[-1])
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=server_port,
        )
        # The local must stay named `metrics`: the f-string below prints it
        # with the self-documenting `{metrics=}` form.
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"Eval accuracy of GSM8K: {metrics=}")

        self.assertGreater(metrics["accuracy"], 0.92)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user