[Feature] Support DeepEP Low Latency (#4767)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: ch-wan <cwan39@gatech.edu>
@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you may
 * `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
 * `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`; for detailed benchmarking, refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
 * `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for the DeepSeek-V3 model, based on deepseek-ai/DeepEP.
+* `deepep_mode`: Selects the mode when DeepEP MoE is enabled; one of `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches (see the launch sketch below).
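A minimal launch sketch using these flags; the checkpoint and parallel sizes below are placeholders, not recommendations:

```python
# Hedged example: launch an sglang server with DeepEP MoE in auto mode.
# Assumes sglang and DeepEP are installed and 8 GPUs are available.
import subprocess

subprocess.run([
    "python3", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V3",  # placeholder checkpoint
    "--tp-size", "8",                           # DeepEP requires ep_size == tp_size here
    "--enable-deepep-moe",
    "--deepep-mode", "auto",                    # low_latency for decode, normal for prefill
])
```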

 ## Memory and scheduling
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
     tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


+# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
+@triton.jit
+def _silu_and_mul_post_quant_kernel(
+    input_ptr,
+    stride_input_0,
+    stride_input_1,
+    stride_input_2,
+    output_ptr,
+    stride_output_0,
+    stride_output_1,
+    stride_output_2,
+    output_scale_ptr,
+    stride_output_scale_0,
+    stride_output_scale_1,
+    stride_output_scale_2,
+    masked_m_ptr,
+    size_n,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    expert_id = tl.program_id(2)
+    token_id = tl.program_id(1)
+    hidden_dim_block_index = tl.program_id(0)
+
+    block_num_per_expert = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
+    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
+    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
+    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
+
+    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
+    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
+    output_scale_offs = (
+        output_scale_ptr
+        + expert_id * stride_output_scale_0
+        + hidden_dim_block_index * stride_output_scale_2
+    )
+
+    for token_index in tl.range(
+        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
+    ):
+        gate = tl.load(
+            input_ptr_offs + token_index * stride_input_1,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        ).to(tl.float32)
+        up = tl.load(
+            input_ptr_offs + token_index * stride_input_1 + size_n,
+            mask=offs_in_d < size_n,
+            other=0.0,
+        )
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
+        output_s = _absmax / fp8_max
+        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
+            output_ptr.dtype.element_ty
+        )
+        tl.store(
+            output_ptr_offs + token_index * stride_output_1,
+            output_q,
+            mask=offs_in_d < size_n,
+        )
+        tl.store(
+            output_scale_offs + token_index * stride_output_scale_1,
+            output_s,
+        )
+
+
+def silu_and_mul_masked_post_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
+    quant_group_size: int,
+    masked_m: torch.Tensor,
+):
+    """
+    input shape [expert_num, token_num_padded, hidden_dim]
+    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
+    output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
+    quant_group_size int
+    masked_m shape [expert_num]
+    """
+    assert input.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert output.is_contiguous()
+    assert len(input.shape) == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+
+    size_n = input.shape[-1] // 2
+    assert size_n % quant_group_size == 0
+
+    expert_num = len(masked_m)
+
+    if expert_num < 4:
+        BLOCK_NUM_PER_EXPERT = 64
+    else:
+        BLOCK_NUM_PER_EXPERT = 32
+
+    BLOCK_N = quant_group_size
+    num_warps = 1
+    NUM_STAGES = 6
+    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
+    assert BLOCK_N % quant_group_size == 0
+
+    grid = (
+        hidden_dim_split_block_num,
+        BLOCK_NUM_PER_EXPERT,
+        expert_num,
+    )
+
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        output_scale,
+        *output_scale.stride(),
+        masked_m,
+        size_n,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+        num_warps=num_warps,
+    )
+    return
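A minimal smoke test for the wrapper above, assuming a CUDA device and the definitions in this file; all sizes are invented and chosen only to satisfy the asserts (hidden dim divisible by 2, group size dividing the half width):

```python
import torch

# Hypothetical sizes: 8 local experts, up to 128 padded tokens each,
# hidden_dim = 512 (gate and up halves of 256), 128-wide quant groups.
E, T, H, G = 8, 128, 512, 128
x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
out = torch.empty(E, T, H // 2, device="cuda", dtype=torch.float8_e4m3fn)
out_scale = torch.empty(E, T, H // 2 // G, device="cuda", dtype=torch.float32)
masked_m = torch.randint(0, T, (E,), device="cuda", dtype=torch.int64)

# Quantizes silu(gate) * up per 128-element group; rows past masked_m[e] are skipped.
silu_and_mul_masked_post_quant_fwd(x, out, out_scale, G, masked_m)
```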
+
+
+@triton.jit
+def tanh(x):
+    return 2 * tl.sigmoid(2 * x) - 1
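The helper relies on the identity tanh(x) = 2·sigmoid(2x) − 1, since `tl.sigmoid` is available inside Triton kernels; a quick host-side check of the identity:

```python
import torch

x = torch.linspace(-4, 4, 9)
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-6)
```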
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple

 import torch

-# TODO: use deep_gemm masked kernel after low latency dispatch
-# import deep_gemm
-# from deep_gemm import (
-#     get_col_major_tma_aligned_tensor,
-#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
-# )
+try:
+    from deep_gemm import (
+        get_col_major_tma_aligned_tensor,
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+    )
+
+    use_deep_gemm = True
+except ImportError:
+    use_deep_gemm = False

 from torch.nn import Module

 from sglang.srt.custom_op import CustomOp
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
@@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        deepep_mode: str = "auto",
     ):
         super().__init__(
             num_experts,
@@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE):
             custom_routing_function,
             activation,
         )
+        self.deepep_mode = deepep_mode
+        if self.deepep_mode in ["low_latency", "auto"]:
+            assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+            )

     def forward(
         self,
         hidden_states: torch.Tensor,
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
+        masked_m: torch.Tensor,
+        expected_m: int,
         forward_mode: ForwardMode,
     ):
-        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
-        if True:  # not forward_mode.is_decode():
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
             return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
         else:
-            return self.forward_deepgemm_masked(
-                hidden_states, reorder_topk_ids, seg_indptr
-            )
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

     def forward_normal(
         self,
@@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE):

     def forward_deepgemm_masked(
         self,
-        hidden_states: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
+        hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
+        masked_m: torch.Tensor,
+        expected_m: int,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_experts_per_partition)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        assert (
+            hidden_states_fp8[0].size(0) % 4 == 0
+        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-        if hidden_states.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            hidden_states = (
-                hidden_states[0],
-                get_col_major_tma_aligned_tensor(hidden_states[1]),
-            )
-        """
-        gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            hidden_states, self.w13_weight, out, masked_m, expected_m
-        )
-        """
+        num_groups, m, k = hidden_states_fp8[0].size()
+        n = self.w13_weight.size(1)
+        expected_m = min(expected_m, m)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )

         # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_experts_per_partition,
-                dtype=torch.float32,
-                device=hidden_states.device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_experts_per_partition - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=gateup_output.device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=gateup_output.device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )

         # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-        if down_input.shape[0] > 0:
-            # Transpose earlier so that the testing will not trigger transposing kernels
-            down_input = (
-                down_input[0],
-                get_col_major_tma_aligned_tensor(down_input[1]),
-            )
-        """
-        down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
-            down_input, self.w2_weight, out, masked_m, expected_m
-        )
-        """
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
+        )
+        m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )

         return down_output
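`m_grouped_gemm_fp8_fp8_bf16_nt_masked` treats each local expert as one GEMM group and computes only the first `masked_m[e]` rows of each group; the padding rows stay untouched. A pure-PyTorch sketch of those semantics (FP8 scaling and the TMA-aligned layout are deliberately ignored; this is not the deep_gemm kernel):

```python
import torch

def masked_grouped_gemm_ref(x, w, masked_m):
    """x: [num_groups, m, k]; w: [num_groups, n, k] (the 'nt' layout); masked_m: [num_groups]."""
    num_groups, m, k = x.shape
    out = torch.zeros(num_groups, m, w.shape[1], dtype=torch.bfloat16)
    for e in range(num_groups):
        rows = int(masked_m[e])  # only these rows hold real tokens; the rest is padding
        out[e, :rows] = (x[e, :rows].float() @ w[e].float().t()).to(torch.bfloat16)
    return out
```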
@@ -76,8 +76,7 @@ def get_buffer_low_latency(
     assert num_experts % group.size() == 0
     _buffer_low_latency = Buffer(
         group,
-        0,
-        num_rdma_bytes,
+        num_rdma_bytes=num_rdma_bytes,
         low_latency_mode=True,
         num_qps_per_rank=num_experts // group.size(),
     )
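For context, `get_buffer_low_latency` sizes the RDMA buffer from DeepEP's size hint before constructing the `Buffer`; a sketch following the DeepEP README (the wrapper function below is hypothetical, and `Buffer.get_low_latency_rdma_size_hint` is assumed from DeepEP's documented API):

```python
from deep_ep import Buffer  # assumption: DeepEP is installed as the `deep_ep` package

def low_latency_rdma_bytes(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts):
    # Size hint for the low-latency dispatch/combine ring buffers.
    return Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts
    )
```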
@@ -95,62 +94,63 @@ class DeepEPDispatcher:
         group: torch.distributed.ProcessGroup,
         router_topk: int,
         permute_fusion: bool = False,
         capacity_factor: float = None,
         num_experts: int = None,
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
+        deepep_mode: str = "auto",
         async_finish: bool = False,
+        return_recv_hook: bool = False,
     ):
-        self.group = group
-        self.router_topk = router_topk
-        self.capacity_factor = capacity_factor
-        self.permute_fusion = permute_fusion
-        self.num_experts = num_experts
-        self.num_local_experts = num_local_experts
-        self.hidden_size = hidden_size
-        self.recv_expert_count = None
-        self.params_dtype = params_dtype
-        self.params_bytes = 2
-        # Metadata
-        self.token_indices = None
-        self.token_probs = None
-        # Handle used for combine operation
-        self.handle = None
-        self.async_finish = async_finish
-
-        # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
-        # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-        self.num_max_dispatch_tokens_per_rank = 128
-
         if not use_deepep:
             raise ImportError(
                 "DeepEP is not installed. Please install DeepEP package from "
                 "https://github.com/deepseek-ai/deepep."
             )
-        self.buffer_normal = get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
-        self.buffer_low_latency = None
-        # Todo: enable low latency dispatch
-        """
-        self.buffer_low_latency = get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size * self.params_bytes,
-            self.num_experts,
-        )
-        """
+
+        self.group = group
+        self.router_topk = router_topk
+        self.permute_fusion = permute_fusion
+        self.num_experts = num_experts
+        self.num_local_experts = num_local_experts
+        self.hidden_size = hidden_size
+        self.params_dtype = params_dtype
+        self.params_bytes = 2
+
+        self.deepep_mode = deepep_mode
+        self.handle = None
+
+        if self.deepep_mode in ["normal", "auto"]:  # for normal / auto mode
+            self.buffer_normal = get_buffer_normal(
+                self.group, self.hidden_size * self.params_bytes
+            )
+            self.async_finish = async_finish
+            self.src2dst = None
+        if self.deepep_mode in ["low_latency", "auto"]:  # for low_latency / auto mode
+            """
+            num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
+            https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+            """
+            # TODO(ch-wan): allow users to set this value
+            self.num_max_dispatch_tokens_per_rank = 128
+            self.buffer_low_latency = get_buffer_low_latency(
+                self.group,
+                self.num_max_dispatch_tokens_per_rank,
+                self.hidden_size,
+                self.num_experts,
+            )
+            self.return_recv_hook = return_recv_hook

     def deepep_permute(
         self,
-        hidden_states,
-        fp8_dtype=None,
-        use_fp8_w8a8=False,
-        use_block_quant=False,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        fp8_dtype: Optional[torch.dtype] = None,
+        use_fp8_w8a8: bool = False,
+        use_block_quant: bool = False,
     ):
-        reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            self.topk_idx, self.num_experts
+        reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            topk_idx, self.num_experts
         )
         num_total_tokens = reorder_topk_ids.numel()
         gateup_input = torch.empty(
@@ -166,14 +166,13 @@ class DeepEPDispatcher:
         deepep_permute_triton_kernel[(hidden_states.shape[0],)](
             hidden_states,
             gateup_input,
-            src2dst,
-            self.topk_idx,
+            self.src2dst,
+            topk_idx,
             None,
             self.router_topk,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
-        self.src2dst = src2dst
         return reorder_topk_ids, seg_indptr, gateup_input

     def dispatch(
@@ -182,54 +181,64 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        forward_mode: ForwardMode,
         num_max_dispatch_tokens_per_rank: int = 128,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        forward_mode: ForwardMode = None,
+    ) -> Tuple:
         topk_idx = topk_idx.to(torch.int64)
-        # Todo: enable low latency dispatch
-        if True:  # not forward_mode.is_decode():
+        reorder_topk_ids = torch.empty(
+            (0,), device=hidden_states.device, dtype=torch.int64
+        )
+        seg_indptr = torch.zeros(
+            (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+        )
+        masked_m = torch.empty(
+            (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
+        )
+        expected_m = 0
+
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
             (
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                num_recv_tokens_per_expert_list,
-                handle,
                 event,
             ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
-            self.tokens_per_expert = torch.tensor(
-                num_recv_tokens_per_expert_list,
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-        else:
-            hidden_states, recv_expert_count, handle, event, hook = (
-                self.dispatch_low_latency(
-                    hidden_states,
-                    topk_idx,
-                    num_max_dispatch_tokens_per_rank,
-                    num_experts,
-                )
-            )
-            self.recv_expert_count = recv_expert_count
-
-        if self.async_finish:
-            event.current_stream_wait()
-
-        self.handle = handle
-        self.topk_idx = topk_idx
-        self.topk_weights = topk_weights
-        if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
-                hidden_states, fp8_dtype=hidden_states.dtype
-            )
-        else:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
-        return hidden_states, reorder_topk_ids, seg_indptr
+            event.current_stream_wait() if self.async_finish else ()
+            if hidden_states.shape[0] > 0:
+                reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+                )
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            expected_m = (
+                hidden_states.shape[0]
+                * self.buffer_low_latency.group_size
+                * topk_idx.shape[1]
+                + num_experts
+            ) // num_experts
+            hidden_states, masked_m, event, hook = self.dispatch_low_latency(
+                hidden_states,
+                topk_idx,
+                num_max_dispatch_tokens_per_rank,
+                num_experts,
+                use_fp8=True,
+            )
+            hook() if self.return_recv_hook else event.current_stream_wait()
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
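`expected_m` is the rounded-up average number of valid rows each expert should expect across the whole dispatch group; a worked example with invented sizes:

```python
# 8 local decode tokens, 16 ranks in the group, top-8 routing, 256 experts:
tokens, group_size, topk, num_experts = 8, 16, 8, 256
expected_m = (tokens * group_size * topk + num_experts) // num_experts
# 1024 routed copies over 256 experts -> 4 each on average; the +num_experts rounds up.
assert expected_m == 5
```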
     def dispatch_normal(
         self,
@@ -254,12 +263,15 @@ class DeepEPDispatcher:
             allocate_on_comm_stream=previous_event is not None,
         )

+        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
+        # However, doing this would incur an unknown synchronization error, but keeping
+        # `handle` as a member variable works.
         (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
+            _,  # num_recv_tokens_per_expert_list
+            self.handle,
             event,
         ) = self.buffer_normal.dispatch(
             x,
@@ -278,8 +290,6 @@ class DeepEPDispatcher:
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            num_recv_tokens_per_expert_list,
-            handle,
             event,
         )
@@ -289,18 +299,19 @@ class DeepEPDispatcher:
         topk_idx: torch.Tensor,
         num_max_dispatch_tokens_per_rank: int,
         num_experts: int,
+        use_fp8: bool = False,
     ):
         """
-        # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
-        # Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
+        # For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
+        # Please make sure to change the DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
         # More details refer to: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782

         diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
-        index f60e933..cddaabf 100644
+        index 76ae2e2..8ecd08f 100644
         --- a/csrc/kernels/internode_ll.cu
         +++ b/csrc/kernels/internode_ll.cu
-        @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
-                     int num_topk, int num_experts, int rank, int num_ranks,
+        @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+                     int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
                      void* workspace, cudaStream_t stream, int phases) {
            constexpr int kNumMaxTopK = 9;
        -   constexpr int kNumWarpsPerGroup = 10;
@@ -308,16 +319,9 @@ class DeepEPDispatcher:
        +   constexpr int kNumWarpsPerGroup = 8;
        +   constexpr int kNumWarpGroups = 4;
            EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
        +
            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
            const auto num_sms = cell_div(num_experts, kNumWarpGroups);
            EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
        -   EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
        +   // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
        +
            // Workspace checks
            auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);

-        @@ -505,8 +505,8 @@ void combine(void* combined_x,
+        @@ -501,8 +501,8 @@ void combine(void* combined_x,
                      int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
                      int num_topk, int num_experts, int rank, int num_ranks,
                      void* workspace, cudaStream_t stream, int phases) {
@@ -326,28 +330,33 @@ class DeepEPDispatcher:
        +   constexpr int kNumWarpsPerGroup = 8;
        +   constexpr int kNumWarpGroups = 4;
            constexpr int kNumMaxTopk = 9;
        +
            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
        """

-        recv_hidden_states, recv_expert_count, handle, event, hook = (
+        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             self.buffer_low_latency.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
                 num_max_dispatch_tokens_per_rank,
                 num_experts,
-                async_finish=self.async_finish,
-                return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+                use_fp8=use_fp8,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
             )
         )
-        # hook()
-        return recv_hidden_states, recv_expert_count, handle, event, hook
+        return packed_recv_hidden, packed_recv_count, event, hook
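With `return_recv_hook=True`, DeepEP returns immediately and hands back a hook that blocks only when called, which is what enables double-batch overlapping. A hedged usage sketch (the wrapper function and `other_work` callback are hypothetical, not sglang API; the dispatcher must have been built with `return_recv_hook=True`):

```python
def overlapped_decode_step(dispatcher, hidden_states, topk_idx, num_experts, other_work):
    """Overlap DeepEP low-latency dispatch with unrelated compute."""
    recv_hidden, recv_count, event, hook = dispatcher.dispatch_low_latency(
        hidden_states, topk_idx, 128, num_experts, use_fp8=True
    )
    other_work()  # runs while the RDMA receive is still in flight
    hook()        # block only at the point the received tokens are needed
    return recv_hidden, recv_count
```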
     def combine(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        # Todo: enable low latency combine
-        if True:  # not forward_mode.is_decode():
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode,
+    ) -> torch.Tensor:
+        if self.deepep_mode == "normal" or (
+            self.deepep_mode == "auto" and not forward_mode.is_decode()
+        ):
             if hidden_states.shape[0] > 0:
                 num_tokens = self.src2dst.shape[0] // self.router_topk
                 output = torch.empty(
@@ -359,8 +368,8 @@ class DeepEPDispatcher:
                     hidden_states,
                     output,
                     self.src2dst,
-                    self.topk_idx,
-                    self.topk_weights,
+                    topk_idx,
+                    topk_weights,
                     self.router_topk,
                     hidden_states.shape[1],
                     BLOCK_SIZE=512,
@@ -371,24 +380,30 @@ class DeepEPDispatcher:
                     device=hidden_states.device,
                     dtype=hidden_states.dtype,
                 )
-            hidden_states, event = self.combine_normal(output, self.handle)
-        else:
-            hidden_states, event, hook = self.combine_low_latency(
-                hidden_states, self.topk_idx, self.topk_weights, self.handle
-            )
+            hidden_states, event = self.combine_normal(
+                output,
+            )
+            event.current_stream_wait() if self.async_finish else ()
+        elif self.deepep_mode == "low_latency" or (
+            self.deepep_mode == "auto" and forward_mode.is_decode()
+        ):
+            hidden_states, event, hook = self.combine_low_latency(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+            )
+            hook() if self.return_recv_hook else event.current_stream_wait()
+        else:
+            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

-        if self.async_finish:
-            event.current_stream_wait()
-
         self.handle = None
         return hidden_states

-    def combine_normal(self, x: torch.Tensor, handle: Tuple):
+    def combine_normal(self, x: torch.Tensor):
         previous_event = Buffer.capture() if self.async_finish else None

         combined_x, _, event = self.buffer_normal.combine(
             x,
-            handle,
+            self.handle,
             async_finish=self.async_finish,
             previous_event=previous_event,
             allocate_on_comm_stream=previous_event is not None,
@@ -400,17 +415,15 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        handle: Tuple,
     ):
-        combined_hidden_states, event_overlap, hook = (
+        combined_hidden_states, event, hook = (
             self.buffer_low_latency.low_latency_combine(
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                handle,
-                async_finish=self.async_finish,
-                return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+                self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
             )
         )
-        # hook()
-        return combined_hidden_states, event_overlap, hook
+        return combined_hidden_states, event, hook
@@ -72,6 +72,7 @@ global_server_args_dict = {
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
+    "deepep_mode": ServerArgs.deepep_mode,
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
@@ -147,6 +147,7 @@ class ModelRunner:
             "enable_dp_attention": server_args.enable_dp_attention,
             "enable_ep_moe": server_args.enable_ep_moe,
             "enable_deepep_moe": server_args.enable_deepep_moe,
+            "deepep_mode": server_args.deepep_mode,
             "device": server_args.device,
             "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
             "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
@@ -272,7 +273,7 @@ class ModelRunner:
             server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
-            logger.info("DeepEP is turned on.")
+            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
-            prefix=add_prefix("experts", prefix),
-        )
+        if not global_server_args_dict["enable_deepep_moe"]:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+            )
+        else:
+            self.experts = MoEImpl(
+                num_experts=config.n_routed_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=config.moe_intermediate_size,
+                renormalize=config.norm_topk_prob,
+                quant_config=quant_config,
+                use_grouped_topk=True,
+                num_expert_group=config.n_group,
+                topk_group=config.topk_group,
+                correction_bias=self.gate.e_score_correction_bias,
+                prefix=add_prefix("experts", prefix),
+                deepep_mode=global_server_args_dict["deepep_mode"],
+            )

         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
@@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
         )

         if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
             self.num_experts = config.n_routed_experts
             self.top_k = config.num_experts_per_tok
             self.renormalize = config.norm_topk_prob
@@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
+                deepep_mode=global_server_args_dict["deepep_mode"],
                 async_finish=True,  # TODO
+                return_recv_hook=True,
             )

     def forward(
@@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
         )
-        if self.tp_size > 1:
-            recv_hidden_states, reorder_topk_ids, seg_indptr = (
-                self.deepep_dispatcher.dispatch(
-                    hidden_states,
-                    topk_idx,
-                    topk_weights,
-                    self.num_experts,
-                    forward_mode,
-                )
-            )
+        if self.ep_size > 1:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                self.num_experts,
+                forward_mode=forward_mode,
+            )
         final_hidden_states = (
             self.experts(
-                hidden_states=recv_hidden_states,
+                hidden_states=hidden_states,
                 reorder_topk_ids=reorder_topk_ids,
                 seg_indptr=seg_indptr,
+                masked_m=masked_m,
+                expected_m=expected_m,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
         )
-        if self.tp_size > 1:
+        if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states, forward_mode
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
             )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
@@ -161,6 +161,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    deepep_mode: Optional[str] = "auto"
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -285,6 +286,13 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size [{self.tp_size}]."
+            )
+
         # Data parallelism attention
         if self.enable_dp_attention:
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
@@ -300,6 +308,10 @@ class ServerArgs:
             self.enable_sp_layernorm = False
         # DeepEP MoE
         if self.enable_deepep_moe:
+            if self.deepep_mode == "auto":
+                assert (
+                    not self.enable_dp_attention
+                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             self.ep_size = self.tp_size
             self.enable_sp_layernorm = (
                 self.dp_size < self.tp_size if self.enable_dp_attention else True
@@ -1082,6 +1094,12 @@ class ServerArgs:
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--deepep-mode",
+            type=str,
+            choices=["normal", "low_latency", "auto"],
+            help="Select the mode when enabling DeepEP MoE: `normal`, `low_latency`, or `auto`. `auto` means `low_latency` for decode batches and `normal` for prefill batches.",
+        )

         # Server warmups
         parser.add_argument(