Adapt W4A8 Marlin MoE to DeepEP with DP + EP
384
python/sglang/srt/layers/moe/ep_moe/layer.py
Normal file → Executable file
@@ -3,6 +3,7 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch

from sglang.srt import single_batch_overlap
@@ -54,7 +55,286 @@ if _use_aiter:
logger = logging.getLogger(__name__)


class DeepEPMoE(FusedMoE):
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    temp = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(temp, 23)
    mant = torch.bitwise_and(temp, 0x7FFFFF)
    is_ru = torch.logical_and(
        torch.logical_and((mant > 0), (exp != 0xFE)),
        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
    )
    exp = torch.where(is_ru, exp + 1, exp)
    new_x = exp.to(torch.uint8).view(torch.int)
    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
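
# Hedged usage sketch (illustrative values, not part of this commit): the cast
# above rounds each float32 scale up to the nearest power of two and keeps only
# its biased-exponent byte, so four e8m0 exponents pack into every int32
# element; rounding up keeps the rescaled values from overflowing.
#
#     scales = torch.full((1, 4, 4), 1.5, device="cuda")  # exp 127, mant != 0
#     packed = _cast_to_e8m0_with_rounding_up(scales)     # 1.5 -> e8m0 2.0
#     assert packed.dtype == torch.int32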


class EPMoE(FusedMoE):
    """
    MoE Expert Parallel Impl
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
        gemm1_alpha: Optional[float] = None,
        gemm1_clamp_limit: Optional[float] = None,
        with_bias: bool = False,
    ):
        super().__init__(
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_fused_shared_experts=num_fused_shared_experts,
            layer_id=layer_id,
            top_k=top_k,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            activation=activation,
            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
            gemm1_alpha=gemm1_alpha,
            gemm1_clamp_limit=gemm1_clamp_limit,
            with_bias=with_bias,
        )

        self.intermediate_size = intermediate_size
        if isinstance(quant_config, Fp8Config):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = True
            self.fp8_dtype = torch.float8_e4m3fn
            self.activation_scheme = quant_config.activation_scheme
            self.use_w4a8_marlin = False
        elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
            self.use_block_quant = getattr(self.quant_method, "block_quant", False)
            self.block_shape = (
                self.quant_method.quant_config.weight_block_size
                if self.use_block_quant
                else None
            )
            self.use_fp8_w8a8 = False
            self.activation_scheme = None
            self.use_w4a8_marlin = True
        else:
            self.use_fp8_w8a8 = False
            self.use_block_quant = False
            self.block_shape = None
            self.activation_scheme = None
            self.use_w4a8_marlin = False
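
        # Net effect of the branches above:
        #   Fp8Config                     -> use_fp8_w8a8=True,  use_w4a8_marlin=False
        #   SlimQuantW4A8Int8MarlinConfig -> use_fp8_w8a8=False, use_w4a8_marlin=True
        #   any other / no quant_config   -> both False (plain FusedMoE path)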

    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
            return self.forward_deepgemm(hidden_states, topk_output)
        else:
            return super().forward(hidden_states, topk_output)

    def forward_deepgemm(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
    ):

        self.w13_weight_fp8 = (
            self.w13_weight,
            (
                self.w13_weight_scale_inv
                if self.use_block_quant
                else self.w13_weight_scale
            ),
        )
        self.w2_weight_fp8 = (
            self.w2_weight,
            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
        )

        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"

        hidden_states_shape = hidden_states.shape
        hidden_states_dtype = hidden_states.dtype
        hidden_states_device = hidden_states.device

        topk_weights, topk_ids, _ = topk_output

        if not self.use_block_quant:
            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
            scale_block_size = 128
            w13_weight_scale_n = 2 * (
                (self.intermediate_size + scale_block_size - 1) // scale_block_size
            )
            w13_weight_scale_k = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w13_weight_scale = (
                self.w13_weight_scale.unsqueeze(1)
                .repeat_interleave(w13_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w13_weight_scale_k, dim=2)
            )
            self.w13_weight_fp8 = (
                self.w13_weight,
                w13_weight_scale,
            )
            w2_weight_scale_n = (
                hidden_states_shape[-1] + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale_k = (
                self.intermediate_size + scale_block_size - 1
            ) // scale_block_size
            w2_weight_scale = (
                self.w2_weight_scale.unsqueeze(1)
                .repeat_interleave(w2_weight_scale_n, dim=1)
                .unsqueeze(2)
                .repeat_interleave(w2_weight_scale_k, dim=2)
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                w2_weight_scale,
            )
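
            # Illustrative shapes for the expansion above (values assumed, e.g.
            # hidden_size=7168, intermediate_size=2048, scale_block_size=128):
            # a per-expert scalar w13 scale of shape (num_experts,) becomes
            # (num_experts, 32, 56), with n = 2 * 2048 // 128 and k = 7168 // 128,
            # i.e. one identical scale per 128x128 weight block for DeepGEMM.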

        # PreReorder
        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
            moe_ep_deepgemm_preprocess(
                topk_ids,
                self.num_experts,
                hidden_states,
                self.top_k,
                self.start_expert_id,
                self.end_expert_id,
                self.block_shape,
            )
        )

        dispose_tensor(hidden_states)

        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
            b, s_mn, s_k = gateup_input_scale.shape
            assert (
                s_mn % 4 == 0 and s_k % 4 == 0
            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"

        # GroupGemm-0
        gateup_input_fp8 = (
            gateup_input,
            (
                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                    gateup_input_scale
                )
            ),
        )
        num_groups, m, k = gateup_input_fp8[0].size()
        n = self.w13_weight.size(1)
        gateup_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            gateup_input_fp8,
            self.w13_weight_fp8,
            gateup_output,
            masked_m,
            expected_m,
        )
        del gateup_input
        del gateup_input_fp8
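
        # Note (assumption about the masked grouped GEMM contract): each expert
        # slot g only has its first masked_m[g] of m rows computed; expected_m
        # is a size hint for kernel selection, and rows past masked_m[g] remain
        # uninitialized padding.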

        # Act
        down_input = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2,
            ),
            device=hidden_states_device,
            dtype=self.fp8_dtype,
        )
        scale_block_size = 128
        down_input_scale = torch.empty(
            (
                gateup_output.shape[0],
                gateup_output.shape[1],
                gateup_output.shape[2] // 2 // scale_block_size,
            ),
            device=hidden_states_device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            scale_block_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        del gateup_output
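
        # The last dim of gateup_output holds the [gate | up] halves, so the
        # kernel above writes silu(gate) * up into a half-width fp8 buffer with
        # one float32 scale per scale_block_size (128) columns.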

        # GroupGemm-1
        n = self.w2_weight.size(1)
        down_input_fp8 = (
            down_input,
            (
                down_input_scale
                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
            ),
        )
        down_output = torch.empty(
            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
            down_input_fp8,
            self.w2_weight_fp8,
            down_output,
            masked_m,
            expected_m,
        )
        del down_input
        del down_input_fp8

        # PostReorder
        output = torch.empty(
            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
        )
        post_reorder_triton_kernel[(hidden_states_shape[0],)](
            down_output,
            output,
            src2dst,
            topk_ids,
            topk_weights,
            self.start_expert_id,
            self.end_expert_id,
            self.top_k,
            hidden_states_shape[1],
            m_max * self.start_expert_id,
            BLOCK_SIZE=512,
        )
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return output
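
# Hedged end-to-end sketch of the EPMoE fp8 path (constructor arguments are
# illustrative, not taken from this commit):
#
#     layer = EPMoE(num_experts=256, top_k=8, hidden_size=7168,
#                   intermediate_size=2048, layer_id=0,
#                   quant_config=fp8_config)
#     out = layer.forward(hidden_states, topk_output)
#     # -> forward_deepgemm when ENABLE_JIT_DEEPGEMM and use_fp8_w8a8 are set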


class DeepEPMoE(EPMoE):
    """
    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
    Mooncake EP shares the same class, as they expose the same interface.
@@ -106,11 +386,28 @@ class DeepEPMoE(FusedMoE):

        self.deepep_mode = get_deepep_mode()

        if self.deepep_mode.enable_low_latency() and not _is_npu:
            # NPU supports low_latency deepep without deepgemm
            assert (
                deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
        # TODO: move to the beginning of the file
        from sglang.srt.distributed.parallel_state import get_tp_group
        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher

        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
            group=get_tp_group().device_group,
            router_topk=self.top_k,
            permute_fusion=True,
            num_experts=self.num_experts,
            num_local_experts=self.num_local_experts,
            hidden_size=hidden_size,
            params_dtype=params_dtype,
            deepep_mode=self.deepep_mode,
            async_finish=True,  # TODO
            return_recv_hook=True,
        )

        # if self.deepep_mode.enable_low_latency() and not _is_npu:
        #     # NPU supports low_latency deepep without deepgemm
        #     assert (
        #         deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
        #     ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
        if _use_aiter:
            # expert_mask is of size (self.num_local_experts + 1); the extra 1
            # is for the invalid rank_id (in the original DeepEP the invalid
            # rank_id is -1, but aiter does not allow -1, so we use a mask to
            # mark those ids as invalid)
@@ -124,23 +421,23 @@ class DeepEPMoE(FusedMoE):
            )
            # the last one is the invalid rank_id
            self.expert_mask[:-1] = 1
        elif not _is_npu:
            self.w13_weight_fp8 = (
                self.w13_weight,
                (
                    self.w13_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w13_weight_scale
                ),
            )
            self.w2_weight_fp8 = (
                self.w2_weight,
                (
                    self.w2_weight_scale_inv
                    if self.use_block_quant or self.use_w4afp8
                    else self.w2_weight_scale
                ),
            )
        # elif not _is_npu:
        #     self.w13_weight_fp8 = (
        #         self.w13_weight,
        #         (
        #             self.w13_weight_scale_inv
        #             if self.use_block_quant
        #             else self.w13_weight_scale
        #         ),
        #     )
        #     self.w2_weight_fp8 = (
        #         self.w2_weight,
        #         (
        #             self.w2_weight_scale_inv
        #             if self.use_block_quant
        #             else self.w2_weight_scale
        #         ),
        #     )

    def forward(
        self,
@@ -187,10 +484,15 @@ class DeepEPMoE(FusedMoE):
            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
            return self.forward_npu(dispatch_output)
        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
            if self.use_w4afp8:
                return self.forward_cutlass_w4afp8(dispatch_output)
            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            return self.forward_deepgemm_contiguous(dispatch_output)
            # assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
            if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
                return self.forward_deepgemm_contiguous(dispatch_output)
            elif self.use_w4a8_marlin:
                return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
            else:
                raise ValueError("Dispatch output is not supported")
        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
            if (
                get_moe_runner_backend().is_flashinfer_cutedsl()
@@ -255,6 +557,34 @@ class DeepEPMoE(FusedMoE):
            expert_mask=self.expert_mask,
        )

    def forward_deepgemm_w4a8_marlin_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
    ):
        hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
            dispatch_output
        )
        assert self.quant_method is not None
        assert self.moe_runner_config.activation == "silu"
        # if num_recv_tokens_per_expert is None:
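        # NOTE: work-in-progress path (an assumption based on the commented-out
        # apply_ep call below): the dispatched int8 activations are cast to
        # bfloat16 and returned without running the Marlin experts yet.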
        return hidden_states_int8.bfloat16()
        # expert_output = self.quant_method.apply_ep(
        #     layer=self,
        #     x=dispatch_output,
        #     topk_weights=topk_weights,
        #     topk_ids=topk_idx,
        #     global_num_experts=self.global_num_experts,
        #     expert_map=self.expert_map,
        #     activation=self.activation,
        #     apply_router_weight_on_input=self.apply_router_weight_on_input,
        #     use_nn_moe=self.use_nn_moe,
        #     num_local_tokens=dispatch_recv_num_token,
        #     config_select_bs=hidden_states.shape[0],
        #     scales=dispatch_scales if self.use_int8_dispatch else None
        #     # routed_scaling_factor=self.routed_scaling_factor,
        # )
        # return expert_output

    def forward_deepgemm_contiguous(
        self,
        dispatch_output: DeepEPNormalOutput,
31
python/sglang/srt/layers/moe/token_dispatcher/deepep.py
Normal file → Executable file
@@ -460,11 +460,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
        overlap_args: Optional["CombineOverlapArgs"],
    ):

        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
            output = hidden_states
        else:
            raise NotImplementedError()  # triton runner was supported but it's temporarily disabled

        # if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
        output = hidden_states
        # else:
        #     if hidden_states.shape[0] > 0:
        #         num_tokens = self.src2dst.shape[0] // self.router_topk
        #         output = torch.empty(
        #             (num_tokens, hidden_states.shape[1]),
        #             device=hidden_states.device,
        #             dtype=hidden_states.dtype,
        #         )
        #         deepep_post_reorder_triton_kernel[(num_tokens,)](
        #             hidden_states,
        #             output,
        #             self.src2dst,
        #             topk_idx,
        #             topk_weights,
        #             self.router_topk,
        #             hidden_states.shape[1],
        #             BLOCK_SIZE=512,
        #         )
        #     else:
        #         output = torch.zeros(
        #             (0, hidden_states.shape[1]),
        #             device=hidden_states.device,
        #             dtype=hidden_states.dtype,
        #         )
        previous_event = Buffer.capture() if self.async_finish else None
        return output, previous_event
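
        # Hedged note: on the DeepGEMM / aiter / NPU paths the tensor is already
        # in the layout combine expects, so no Triton post-reorder runs and the
        # buffer is forwarded as-is; Buffer.capture() snapshots the current CUDA
        # event so combine can be ordered after compute when async_finish is set.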