[Feature] Support DeepEP Low Latency (#4767)
Co-authored-by: sleepcoo <sleepcoo@gmail.com> Co-authored-by: laixinn <xielx@shanghaitech.edu.cn> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
|||||||
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
|
* `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models.
|
||||||
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
|
* `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`.
|
||||||
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
|
* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP.
|
||||||
|
* `deepep_mode`: Selects the mode used when DeepEP MoE is enabled; one of `normal`, `low_latency`, or `auto`. The default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches.
|
||||||
|
|
||||||
## Memory and scheduling
|
## Memory and scheduling
|
||||||
|
|
||||||
|
|||||||
@@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel(
|
|||||||
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
|
||||||
|
@triton.jit
|
||||||
|
def _silu_and_mul_post_quant_kernel(
|
||||||
|
input_ptr,
|
||||||
|
stride_input_0,
|
||||||
|
stride_input_1,
|
||||||
|
stride_input_2,
|
||||||
|
output_ptr,
|
||||||
|
stride_output_0,
|
||||||
|
stride_output_1,
|
||||||
|
stride_output_2,
|
||||||
|
output_scale_ptr,
|
||||||
|
stride_output_scale_0,
|
||||||
|
stride_output_scale_1,
|
||||||
|
stride_output_scale_2,
|
||||||
|
masked_m_ptr,
|
||||||
|
size_n,
|
||||||
|
fp8_max,
|
||||||
|
fp8_min,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
NUM_STAGE: tl.constexpr,
|
||||||
|
):
|
||||||
|
expert_id = tl.program_id(2)
|
||||||
|
token_id = tl.program_id(1)
|
||||||
|
hidden_dim_block_index = tl.program_id(0)
|
||||||
|
|
||||||
|
block_num_per_expert = tl.num_programs(1)
|
||||||
|
|
||||||
|
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
|
||||||
|
|
||||||
|
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
|
||||||
|
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
|
||||||
|
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
|
||||||
|
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
|
||||||
|
|
||||||
|
offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||||
|
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
|
||||||
|
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
|
||||||
|
output_scale_offs = (
|
||||||
|
output_scale_ptr
|
||||||
|
+ expert_id * stride_output_scale_0
|
||||||
|
+ hidden_dim_block_index * stride_output_scale_2
|
||||||
|
)
|
||||||
|
|
||||||
|
for token_index in tl.range(
|
||||||
|
token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
|
||||||
|
):
|
||||||
|
gate = tl.load(
|
||||||
|
input_ptr_offs + token_index * stride_input_1,
|
||||||
|
mask=offs_in_d < size_n,
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
up = tl.load(
|
||||||
|
input_ptr_offs + token_index * stride_input_1 + size_n,
|
||||||
|
mask=offs_in_d < size_n,
|
||||||
|
other=0.0,
|
||||||
|
)
|
||||||
|
gate = gate / (1 + tl.exp(-gate))
|
||||||
|
gate = gate.to(input_ptr.dtype.element_ty)
|
||||||
|
gate_up = up * gate
|
||||||
|
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
|
||||||
|
output_s = _absmax / fp8_max
|
||||||
|
output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
|
||||||
|
output_ptr.dtype.element_ty
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
output_ptr_offs + token_index * stride_output_1,
|
||||||
|
output_q,
|
||||||
|
mask=offs_in_d < size_n,
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
output_scale_offs + token_index * stride_output_scale_1,
|
||||||
|
output_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def silu_and_mul_masked_post_quant_fwd(
|
||||||
|
input: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
output_scale: torch.Tensor,
|
||||||
|
quant_group_size: int,
|
||||||
|
masked_m: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
input shape [expert_num, token_num_padded, hidden_dim]
|
||||||
|
output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
|
||||||
|
output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
|
||||||
|
quant_group_size int,
|
||||||
|
masked_m shape [expert_num],
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert input.is_contiguous()
|
||||||
|
assert output.dtype == torch.float8_e4m3fn
|
||||||
|
assert output.is_contiguous()
|
||||||
|
assert len(input.shape) == 3
|
||||||
|
assert input.shape[0] == masked_m.shape[0]
|
||||||
|
assert input.shape[-1] % 2 == 0
|
||||||
|
|
||||||
|
size_n = input.shape[-1] // 2
|
||||||
|
assert size_n % quant_group_size == 0
|
||||||
|
|
||||||
|
expert_num = len(masked_m)
|
||||||
|
|
||||||
|
if expert_num < 4:
|
||||||
|
BLOCK_NUM_PER_EXPERT = 64
|
||||||
|
else:
|
||||||
|
BLOCK_NUM_PER_EXPERT = 32
|
||||||
|
|
||||||
|
BLOCK_N = quant_group_size
|
||||||
|
num_warps = 1
|
||||||
|
NUM_STAGES = 6
|
||||||
|
hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
|
||||||
|
assert BLOCK_N % quant_group_size == 0
|
||||||
|
|
||||||
|
grid = (
|
||||||
|
hidden_dim_split_block_num,
|
||||||
|
BLOCK_NUM_PER_EXPERT,
|
||||||
|
expert_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max = finfo.max
|
||||||
|
fp8_min = -fp8_max
|
||||||
|
|
||||||
|
_silu_and_mul_post_quant_kernel[grid](
|
||||||
|
input,
|
||||||
|
*input.stride(),
|
||||||
|
output,
|
||||||
|
*output.stride(),
|
||||||
|
output_scale,
|
||||||
|
*output_scale.stride(),
|
||||||
|
masked_m,
|
||||||
|
size_n,
|
||||||
|
fp8_max,
|
||||||
|
fp8_min,
|
||||||
|
BLOCK_N=BLOCK_N,
|
||||||
|
NUM_STAGE=NUM_STAGES,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def tanh(x):
|
def tanh(x):
|
||||||
return 2 * tl.sigmoid(2 * x) - 1
|
return 2 * tl.sigmoid(2 * x) - 1
|
||||||
|
|||||||
@@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# TODO: use deep_gemm masked kernel after low latency dispatch
|
try:
|
||||||
# import deep_gemm
|
from deep_gemm import (
|
||||||
# from deep_gemm import (
|
get_col_major_tma_aligned_tensor,
|
||||||
# get_col_major_tma_aligned_tensor,
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
||||||
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
|
)
|
||||||
# )
|
|
||||||
|
use_deep_gemm = True
|
||||||
|
except ImportError:
|
||||||
|
use_deep_gemm = False
|
||||||
|
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
@@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
pre_reorder_triton_kernel,
|
pre_reorder_triton_kernel,
|
||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
|
silu_and_mul_masked_post_quant_fwd,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
@@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
correction_bias: Optional[torch.Tensor] = None,
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
deepep_mode: str = "auto",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE):
|
|||||||
custom_routing_function,
|
custom_routing_function,
|
||||||
activation,
|
activation,
|
||||||
)
|
)
|
||||||
|
self.deepep_mode = deepep_mode
|
||||||
|
if self.deepep_mode in ["low_latency", "auto"]:
|
||||||
|
assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
|
self.w13_weight_fp8 = (
|
||||||
|
self.w13_weight,
|
||||||
|
(
|
||||||
|
self.w13_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w13_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.w2_weight_fp8 = (
|
||||||
|
self.w2_weight,
|
||||||
|
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
reorder_topk_ids: torch.Tensor,
|
reorder_topk_ids: torch.Tensor,
|
||||||
seg_indptr: torch.Tensor,
|
seg_indptr: torch.Tensor,
|
||||||
|
masked_m: torch.Tensor,
|
||||||
|
expected_m: int,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
|
if self.deepep_mode == "normal" or (
|
||||||
if True: # not forward_mode.is_decode():
|
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
||||||
|
):
|
||||||
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||||
|
elif self.deepep_mode == "low_latency" or (
|
||||||
|
self.deepep_mode == "auto" and forward_mode.is_decode()
|
||||||
|
):
|
||||||
|
return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepgemm_masked(
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
hidden_states, reorder_topk_ids, seg_indptr
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward_normal(
|
def forward_normal(
|
||||||
self,
|
self,
|
||||||
@@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE):
|
|||||||
|
|
||||||
def forward_deepgemm_masked(
|
def forward_deepgemm_masked(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||||
reorder_topk_ids: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
seg_indptr: torch.Tensor,
|
expected_m: int,
|
||||||
):
|
):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.activation == "silu"
|
||||||
|
assert (
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
hidden_states_fp8[0].size(0) % 4 == 0
|
||||||
max_value = (
|
), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"
|
||||||
torch.max(hidden_states)
|
|
||||||
.repeat(self.num_experts_per_partition)
|
|
||||||
.to(torch.float32)
|
|
||||||
)
|
|
||||||
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
|
|
||||||
|
|
||||||
# GroupGemm-0
|
# GroupGemm-0
|
||||||
|
num_groups, m, k = hidden_states_fp8[0].size()
|
||||||
|
n = self.w13_weight.size(1)
|
||||||
|
expected_m = min(expected_m, m)
|
||||||
gateup_output = torch.empty(
|
gateup_output = torch.empty(
|
||||||
hidden_states.shape[0],
|
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
|
||||||
self.w13_weight.shape[1],
|
)
|
||||||
device=hidden_states.device,
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||||
dtype=hidden_states.dtype,
|
hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
|
||||||
)
|
)
|
||||||
if hidden_states.shape[0] > 0:
|
|
||||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
|
||||||
hidden_states = (
|
|
||||||
hidden_states[0],
|
|
||||||
get_col_major_tma_aligned_tensor(hidden_states[1]),
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
|
||||||
hidden_states, self.w13_weight, out, masked_m, expected_m
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
down_input = torch.empty(
|
down_input = torch.empty(
|
||||||
gateup_output.shape[0],
|
(
|
||||||
gateup_output.shape[1] // 2,
|
gateup_output.shape[0],
|
||||||
device=gateup_output.device,
|
|
||||||
dtype=(
|
|
||||||
self.fp8_dtype
|
|
||||||
if (self.use_fp8_w8a8 and not self.use_block_quant)
|
|
||||||
else hidden_states.dtype
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if self.w2_input_scale is None and not self.use_block_quant:
|
|
||||||
self.w2_input_scale = torch.ones(
|
|
||||||
self.num_experts_per_partition,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=hidden_states.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.activation == "silu":
|
|
||||||
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
|
||||||
gateup_output,
|
|
||||||
down_input,
|
|
||||||
gateup_output.shape[1],
|
gateup_output.shape[1],
|
||||||
reorder_topk_ids,
|
gateup_output.shape[2] // 2,
|
||||||
self.w2_input_scale,
|
),
|
||||||
0,
|
device=gateup_output.device,
|
||||||
self.num_experts_per_partition - 1,
|
dtype=self.fp8_dtype,
|
||||||
BLOCK_SIZE=512,
|
)
|
||||||
)
|
scale_block_size = 128
|
||||||
else:
|
down_input_scale = torch.empty(
|
||||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
(
|
||||||
|
gateup_output.shape[0],
|
||||||
|
gateup_output.shape[1],
|
||||||
|
gateup_output.shape[2] // 2 // scale_block_size,
|
||||||
|
),
|
||||||
|
device=gateup_output.device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
silu_and_mul_masked_post_quant_fwd(
|
||||||
|
gateup_output,
|
||||||
|
down_input,
|
||||||
|
down_input_scale,
|
||||||
|
scale_block_size,
|
||||||
|
masked_m,
|
||||||
|
)
|
||||||
|
|
||||||
# GroupGemm-1
|
# GroupGemm-1
|
||||||
down_output = torch.empty(
|
n = self.w2_weight.size(1)
|
||||||
down_input.shape[0],
|
down_input_fp8 = (
|
||||||
self.w2_weight.shape[1],
|
down_input,
|
||||||
device=hidden_states.device,
|
get_col_major_tma_aligned_tensor(down_input_scale),
|
||||||
dtype=hidden_states.dtype,
|
)
|
||||||
|
down_output = torch.empty(
|
||||||
|
(num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
||||||
|
down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
|
||||||
)
|
)
|
||||||
if down_input.shape[0] > 0:
|
|
||||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
|
||||||
down_input = (
|
|
||||||
down_input[0],
|
|
||||||
get_col_major_tma_aligned_tensor(down_input[1]),
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
|
|
||||||
down_input, self.w2_weight, out, masked_m, expected_m
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return down_output
|
return down_output
|
||||||
|
|||||||
@@ -76,8 +76,7 @@ def get_buffer_low_latency(
|
|||||||
assert num_experts % group.size() == 0
|
assert num_experts % group.size() == 0
|
||||||
_buffer_low_latency = Buffer(
|
_buffer_low_latency = Buffer(
|
||||||
group,
|
group,
|
||||||
0,
|
num_rdma_bytes=num_rdma_bytes,
|
||||||
num_rdma_bytes,
|
|
||||||
low_latency_mode=True,
|
low_latency_mode=True,
|
||||||
num_qps_per_rank=num_experts // group.size(),
|
num_qps_per_rank=num_experts // group.size(),
|
||||||
)
|
)
|
||||||
@@ -95,62 +94,63 @@ class DeepEPDispatcher:
|
|||||||
group: torch.distributed.ProcessGroup,
|
group: torch.distributed.ProcessGroup,
|
||||||
router_topk: int,
|
router_topk: int,
|
||||||
permute_fusion: bool = False,
|
permute_fusion: bool = False,
|
||||||
capacity_factor: float = None,
|
|
||||||
num_experts: int = None,
|
num_experts: int = None,
|
||||||
num_local_experts: int = None,
|
num_local_experts: int = None,
|
||||||
hidden_size: int = None,
|
hidden_size: int = None,
|
||||||
params_dtype: torch.dtype = None,
|
params_dtype: torch.dtype = None,
|
||||||
|
deepep_mode: str = "auto",
|
||||||
async_finish: bool = False,
|
async_finish: bool = False,
|
||||||
|
return_recv_hook: bool = False,
|
||||||
):
|
):
|
||||||
self.group = group
|
|
||||||
self.router_topk = router_topk
|
|
||||||
self.capacity_factor = capacity_factor
|
|
||||||
self.permute_fusion = permute_fusion
|
|
||||||
self.num_experts = num_experts
|
|
||||||
self.num_local_experts = num_local_experts
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.recv_expert_count = None
|
|
||||||
self.params_dtype = params_dtype
|
|
||||||
self.params_bytes = 2
|
|
||||||
# Metadata
|
|
||||||
self.token_indices = None
|
|
||||||
self.token_probs = None
|
|
||||||
# Handle used for combine operation
|
|
||||||
self.handle = None
|
|
||||||
self.async_finish = async_finish
|
|
||||||
|
|
||||||
# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
|
|
||||||
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
|
||||||
self.num_max_dispatch_tokens_per_rank = 128
|
|
||||||
|
|
||||||
if not use_deepep:
|
if not use_deepep:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"DeepEP is not installed. Please install DeepEP package from "
|
"DeepEP is not installed. Please install DeepEP package from "
|
||||||
"https://github.com/deepseek-ai/deepep."
|
"https://github.com/deepseek-ai/deepep."
|
||||||
)
|
)
|
||||||
self.buffer_normal = get_buffer_normal(
|
|
||||||
self.group, self.hidden_size * self.params_bytes
|
self.group = group
|
||||||
)
|
self.router_topk = router_topk
|
||||||
self.buffer_low_latency = None
|
self.permute_fusion = permute_fusion
|
||||||
# Todo: enable low latency dispatch
|
self.num_experts = num_experts
|
||||||
"""
|
self.num_local_experts = num_local_experts
|
||||||
self.buffer_low_latency = get_buffer_low_latency(
|
self.hidden_size = hidden_size
|
||||||
self.group,
|
self.params_dtype = params_dtype
|
||||||
self.num_max_dispatch_tokens_per_rank,
|
self.params_bytes = 2
|
||||||
self.hidden_size * self.params_bytes,
|
|
||||||
self.num_experts,
|
self.deepep_mode = deepep_mode
|
||||||
)
|
self.handle = None
|
||||||
"""
|
|
||||||
|
if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode
|
||||||
|
self.buffer_normal = get_buffer_normal(
|
||||||
|
self.group, self.hidden_size * self.params_bytes
|
||||||
|
)
|
||||||
|
self.async_finish = async_finish
|
||||||
|
self.src2dst = None
|
||||||
|
if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode
|
||||||
|
"""
|
||||||
|
num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
|
||||||
|
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||||
|
"""
|
||||||
|
# TODO(ch-wan): allow users to set this value
|
||||||
|
self.num_max_dispatch_tokens_per_rank = 128
|
||||||
|
self.buffer_low_latency = get_buffer_low_latency(
|
||||||
|
self.group,
|
||||||
|
self.num_max_dispatch_tokens_per_rank,
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_experts,
|
||||||
|
)
|
||||||
|
self.return_recv_hook = return_recv_hook
|
||||||
|
|
||||||
def deepep_permute(
|
def deepep_permute(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
fp8_dtype=None,
|
topk_idx: torch.Tensor,
|
||||||
use_fp8_w8a8=False,
|
fp8_dtype: Optional[torch.dtype] = None,
|
||||||
use_block_quant=False,
|
use_fp8_w8a8: bool = False,
|
||||||
|
use_block_quant: bool = False,
|
||||||
):
|
):
|
||||||
reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||||
self.topk_idx, self.num_experts
|
topk_idx, self.num_experts
|
||||||
)
|
)
|
||||||
num_total_tokens = reorder_topk_ids.numel()
|
num_total_tokens = reorder_topk_ids.numel()
|
||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
@@ -166,14 +166,13 @@ class DeepEPDispatcher:
|
|||||||
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
|
||||||
hidden_states,
|
hidden_states,
|
||||||
gateup_input,
|
gateup_input,
|
||||||
src2dst,
|
self.src2dst,
|
||||||
self.topk_idx,
|
topk_idx,
|
||||||
None,
|
None,
|
||||||
self.router_topk,
|
self.router_topk,
|
||||||
hidden_states.shape[1],
|
hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
self.src2dst = src2dst
|
|
||||||
return reorder_topk_ids, seg_indptr, gateup_input
|
return reorder_topk_ids, seg_indptr, gateup_input
|
||||||
|
|
||||||
def dispatch(
|
def dispatch(
|
||||||
@@ -182,54 +181,64 @@ class DeepEPDispatcher:
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
forward_mode: ForwardMode,
|
|
||||||
num_max_dispatch_tokens_per_rank: int = 128,
|
num_max_dispatch_tokens_per_rank: int = 128,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
forward_mode: ForwardMode = None,
|
||||||
|
) -> Tuple:
|
||||||
topk_idx = topk_idx.to(torch.int64)
|
topk_idx = topk_idx.to(torch.int64)
|
||||||
# Todo: enable low latency dispatch
|
reorder_topk_ids = torch.empty(
|
||||||
if True: # not forward_mode.is_decode():
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
masked_m = torch.empty(
|
||||||
|
(self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
expected_m = 0
|
||||||
|
|
||||||
|
if self.deepep_mode == "normal" or (
|
||||||
|
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
||||||
|
):
|
||||||
(
|
(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
num_recv_tokens_per_expert_list,
|
|
||||||
handle,
|
|
||||||
event,
|
event,
|
||||||
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
|
||||||
self.tokens_per_expert = torch.tensor(
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
num_recv_tokens_per_expert_list,
|
if hidden_states.shape[0] > 0:
|
||||||
device=hidden_states.device,
|
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
||||||
dtype=torch.int64,
|
hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, recv_expert_count, handle, event, hook = (
|
|
||||||
self.dispatch_low_latency(
|
|
||||||
hidden_states,
|
|
||||||
topk_idx,
|
|
||||||
num_max_dispatch_tokens_per_rank,
|
|
||||||
num_experts,
|
|
||||||
)
|
)
|
||||||
|
elif self.deepep_mode == "low_latency" or (
|
||||||
|
self.deepep_mode == "auto" and forward_mode.is_decode()
|
||||||
|
):
|
||||||
|
expected_m = (
|
||||||
|
hidden_states.shape[0]
|
||||||
|
* self.buffer_low_latency.group_size
|
||||||
|
* topk_idx.shape[1]
|
||||||
|
+ num_experts
|
||||||
|
) // num_experts
|
||||||
|
hidden_states, masked_m, event, hook = self.dispatch_low_latency(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
num_max_dispatch_tokens_per_rank,
|
||||||
|
num_experts,
|
||||||
|
use_fp8=True,
|
||||||
)
|
)
|
||||||
self.recv_expert_count = recv_expert_count
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
|
||||||
if self.async_finish:
|
|
||||||
event.current_stream_wait()
|
|
||||||
|
|
||||||
self.handle = handle
|
|
||||||
self.topk_idx = topk_idx
|
|
||||||
self.topk_weights = topk_weights
|
|
||||||
if hidden_states.shape[0] > 0:
|
|
||||||
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
|
||||||
hidden_states, fp8_dtype=hidden_states.dtype
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
reorder_topk_ids = torch.empty(
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
(0,), device=hidden_states.device, dtype=torch.int64
|
|
||||||
)
|
return (
|
||||||
seg_indptr = torch.zeros(
|
hidden_states,
|
||||||
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
topk_idx,
|
||||||
)
|
topk_weights,
|
||||||
return hidden_states, reorder_topk_ids, seg_indptr
|
reorder_topk_ids,
|
||||||
|
seg_indptr,
|
||||||
|
masked_m,
|
||||||
|
expected_m,
|
||||||
|
)
|
||||||
|
|
||||||
def dispatch_normal(
|
def dispatch_normal(
|
||||||
self,
|
self,
|
||||||
@@ -254,12 +263,15 @@ class DeepEPDispatcher:
|
|||||||
allocate_on_comm_stream=previous_event is not None,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME: `handle` should be transmitted with tokens from dispatch to combine.
|
||||||
|
# However, doing this would incur an unknown synchronization error, but keeping
|
||||||
|
# `handle` as a member variable works.
|
||||||
(
|
(
|
||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
num_recv_tokens_per_expert_list,
|
_, # num_recv_tokens_per_expert_list
|
||||||
handle,
|
self.handle,
|
||||||
event,
|
event,
|
||||||
) = self.buffer_normal.dispatch(
|
) = self.buffer_normal.dispatch(
|
||||||
x,
|
x,
|
||||||
@@ -278,8 +290,6 @@ class DeepEPDispatcher:
|
|||||||
recv_x,
|
recv_x,
|
||||||
recv_topk_idx,
|
recv_topk_idx,
|
||||||
recv_topk_weights,
|
recv_topk_weights,
|
||||||
num_recv_tokens_per_expert_list,
|
|
||||||
handle,
|
|
||||||
event,
|
event,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,18 +299,19 @@ class DeepEPDispatcher:
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
num_max_dispatch_tokens_per_rank: int,
|
num_max_dispatch_tokens_per_rank: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
|
use_fp8: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
# For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
|
# For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'.
|
||||||
# Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
|
# Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall.
|
||||||
# More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
|
# More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
|
||||||
+
|
|
||||||
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
|
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
|
||||||
index f60e933..cddaabf 100644
|
index 76ae2e2..8ecd08f 100644
|
||||||
--- a/csrc/kernels/internode_ll.cu
|
--- a/csrc/kernels/internode_ll.cu
|
||||||
+++ b/csrc/kernels/internode_ll.cu
|
+++ b/csrc/kernels/internode_ll.cu
|
||||||
@@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
@@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
|
||||||
int num_topk, int num_experts, int rank, int num_ranks,
|
int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8,
|
||||||
void* workspace, cudaStream_t stream, int phases) {
|
void* workspace, cudaStream_t stream, int phases) {
|
||||||
constexpr int kNumMaxTopK = 9;
|
constexpr int kNumMaxTopK = 9;
|
||||||
- constexpr int kNumWarpsPerGroup = 10;
|
- constexpr int kNumWarpsPerGroup = 10;
|
||||||
@@ -308,16 +319,9 @@ class DeepEPDispatcher:
|
|||||||
+ constexpr int kNumWarpsPerGroup = 8;
|
+ constexpr int kNumWarpsPerGroup = 8;
|
||||||
+ constexpr int kNumWarpGroups = 4;
|
+ constexpr int kNumWarpGroups = 4;
|
||||||
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
|
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
|
||||||
+
|
|
||||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||||
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
|
@@ -501,8 +501,8 @@ void combine(void* combined_x,
|
||||||
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
|
|
||||||
- EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
|
|
||||||
+ // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
|
|
||||||
+
|
|
||||||
// Workspace checks
|
|
||||||
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
|
|
||||||
@@ -505,8 +505,8 @@ void combine(void* combined_x,
|
|
||||||
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
|
||||||
int num_topk, int num_experts, int rank, int num_ranks,
|
int num_topk, int num_experts, int rank, int num_ranks,
|
||||||
void* workspace, cudaStream_t stream, int phases) {
|
void* workspace, cudaStream_t stream, int phases) {
|
||||||
@@ -326,28 +330,33 @@ class DeepEPDispatcher:
|
|||||||
+ constexpr int kNumWarpsPerGroup = 8;
|
+ constexpr int kNumWarpsPerGroup = 8;
|
||||||
+ constexpr int kNumWarpGroups = 4;
|
+ constexpr int kNumWarpGroups = 4;
|
||||||
constexpr int kNumMaxTopk = 9;
|
constexpr int kNumMaxTopk = 9;
|
||||||
+
|
|
||||||
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
recv_hidden_states, recv_expert_count, handle, event, hook = (
|
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
||||||
self.buffer_low_latency.low_latency_dispatch(
|
self.buffer_low_latency.low_latency_dispatch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
num_max_dispatch_tokens_per_rank,
|
num_max_dispatch_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
async_finish=self.async_finish,
|
use_fp8=use_fp8,
|
||||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
async_finish=not self.return_recv_hook,
|
||||||
|
return_recv_hook=self.return_recv_hook,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# hook()
|
return packed_recv_hidden, packed_recv_count, event, hook
|
||||||
return recv_hidden_states, recv_expert_count, handle, event, hook
|
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
|
self,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
hidden_states: torch.Tensor,
|
||||||
# Todo: enable low latency combine
|
topk_idx: torch.Tensor,
|
||||||
if True: # not forward_mode.is_decode():
|
topk_weights: torch.Tensor,
|
||||||
|
forward_mode: ForwardMode,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.deepep_mode == "normal" or (
|
||||||
|
self.deepep_mode == "auto" and not forward_mode.is_decode()
|
||||||
|
):
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
@@ -359,8 +368,8 @@ class DeepEPDispatcher:
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
output,
|
output,
|
||||||
self.src2dst,
|
self.src2dst,
|
||||||
self.topk_idx,
|
topk_idx,
|
||||||
self.topk_weights,
|
topk_weights,
|
||||||
self.router_topk,
|
self.router_topk,
|
||||||
hidden_states.shape[1],
|
hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
@@ -371,24 +380,30 @@ class DeepEPDispatcher:
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
hidden_states, event = self.combine_normal(output, self.handle)
|
hidden_states, event = self.combine_normal(
|
||||||
else:
|
output,
|
||||||
hidden_states, event, hook = self.combine_low_latency(
|
|
||||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
|
||||||
)
|
)
|
||||||
|
event.current_stream_wait() if self.async_finish else ()
|
||||||
|
elif self.deepep_mode == "low_latency" or (
|
||||||
|
self.deepep_mode == "auto" and forward_mode.is_decode()
|
||||||
|
):
|
||||||
|
hidden_states, event, hook = self.combine_low_latency(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
)
|
||||||
|
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
|
||||||
|
|
||||||
if self.async_finish:
|
|
||||||
event.current_stream_wait()
|
|
||||||
|
|
||||||
self.handle = None
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def combine_normal(self, x: torch.Tensor, handle: Tuple):
|
def combine_normal(self, x: torch.Tensor):
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
|
|
||||||
combined_x, _, event = self.buffer_normal.combine(
|
combined_x, _, event = self.buffer_normal.combine(
|
||||||
x,
|
x,
|
||||||
handle,
|
self.handle,
|
||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
allocate_on_comm_stream=previous_event is not None,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
@@ -400,17 +415,15 @@ class DeepEPDispatcher:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
handle: Tuple,
|
|
||||||
):
|
):
|
||||||
combined_hidden_states, event_overlap, hook = (
|
combined_hidden_states, event, hook = (
|
||||||
self.buffer_low_latency.low_latency_combine(
|
self.buffer_low_latency.low_latency_combine(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
handle,
|
self.handle,
|
||||||
async_finish=self.async_finish,
|
async_finish=not self.return_recv_hook,
|
||||||
return_recv_hook=False, # True for double-batch overlapping, need call hook()
|
return_recv_hook=self.return_recv_hook,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# hook()
|
return combined_hidden_states, event, hook
|
||||||
return combined_hidden_states, event_overlap, hook
|
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ global_server_args_dict = {
|
|||||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||||
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
||||||
|
"deepep_mode": ServerArgs.deepep_mode,
|
||||||
"device": ServerArgs.device,
|
"device": ServerArgs.device,
|
||||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ class ModelRunner:
|
|||||||
"enable_dp_attention": server_args.enable_dp_attention,
|
"enable_dp_attention": server_args.enable_dp_attention,
|
||||||
"enable_ep_moe": server_args.enable_ep_moe,
|
"enable_ep_moe": server_args.enable_ep_moe,
|
||||||
"enable_deepep_moe": server_args.enable_deepep_moe,
|
"enable_deepep_moe": server_args.enable_deepep_moe,
|
||||||
|
"deepep_mode": server_args.deepep_mode,
|
||||||
"device": server_args.device,
|
"device": server_args.device,
|
||||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||||
@@ -272,7 +273,7 @@ class ModelRunner:
|
|||||||
server_args.disable_radix_cache = True
|
server_args.disable_radix_cache = True
|
||||||
|
|
||||||
if server_args.enable_deepep_moe:
|
if server_args.enable_deepep_moe:
|
||||||
logger.info("DeepEP is turned on.")
|
logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
|
||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|||||||
@@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if global_server_args_dict["enable_deepep_moe"]
|
if global_server_args_dict["enable_deepep_moe"]
|
||||||
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
||||||
)
|
)
|
||||||
self.experts = MoEImpl(
|
if not global_server_args_dict["enable_deepep_moe"]:
|
||||||
num_experts=config.n_routed_experts,
|
self.experts = MoEImpl(
|
||||||
top_k=config.num_experts_per_tok,
|
num_experts=config.n_routed_experts,
|
||||||
hidden_size=config.hidden_size,
|
top_k=config.num_experts_per_tok,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
hidden_size=config.hidden_size,
|
||||||
renormalize=config.norm_topk_prob,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
quant_config=quant_config,
|
renormalize=config.norm_topk_prob,
|
||||||
use_grouped_topk=True,
|
quant_config=quant_config,
|
||||||
num_expert_group=config.n_group,
|
use_grouped_topk=True,
|
||||||
topk_group=config.topk_group,
|
num_expert_group=config.n_group,
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
topk_group=config.topk_group,
|
||||||
prefix=add_prefix("experts", prefix),
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
)
|
prefix=add_prefix("experts", prefix),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.experts = MoEImpl(
|
||||||
|
num_experts=config.n_routed_experts,
|
||||||
|
top_k=config.num_experts_per_tok,
|
||||||
|
hidden_size=config.hidden_size,
|
||||||
|
intermediate_size=config.moe_intermediate_size,
|
||||||
|
renormalize=config.norm_topk_prob,
|
||||||
|
quant_config=quant_config,
|
||||||
|
use_grouped_topk=True,
|
||||||
|
num_expert_group=config.n_group,
|
||||||
|
topk_group=config.topk_group,
|
||||||
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
|
prefix=add_prefix("experts", prefix),
|
||||||
|
deepep_mode=global_server_args_dict["deepep_mode"],
|
||||||
|
)
|
||||||
|
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||||
@@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if global_server_args_dict["enable_deepep_moe"]:
|
if global_server_args_dict["enable_deepep_moe"]:
|
||||||
|
# TODO: we will support tp < ep in the future
|
||||||
|
self.ep_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_experts = config.n_routed_experts
|
self.num_experts = config.n_routed_experts
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
self.renormalize = config.norm_topk_prob
|
self.renormalize = config.norm_topk_prob
|
||||||
@@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
params_dtype=config.torch_dtype,
|
params_dtype=config.torch_dtype,
|
||||||
|
deepep_mode=global_server_args_dict["deepep_mode"],
|
||||||
async_finish=True, # TODO
|
async_finish=True, # TODO
|
||||||
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
)
|
)
|
||||||
if self.tp_size > 1:
|
if self.ep_size > 1:
|
||||||
recv_hidden_states, reorder_topk_ids, seg_indptr = (
|
(
|
||||||
self.deepep_dispatcher.dispatch(
|
hidden_states,
|
||||||
hidden_states,
|
topk_idx,
|
||||||
topk_idx,
|
topk_weights,
|
||||||
topk_weights,
|
reorder_topk_ids,
|
||||||
self.num_experts,
|
seg_indptr,
|
||||||
forward_mode,
|
masked_m,
|
||||||
)
|
expected_m,
|
||||||
|
) = self.deepep_dispatcher.dispatch(
|
||||||
|
hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
self.num_experts,
|
||||||
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
self.experts(
|
self.experts(
|
||||||
hidden_states=recv_hidden_states,
|
hidden_states=hidden_states,
|
||||||
reorder_topk_ids=reorder_topk_ids,
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
seg_indptr=seg_indptr,
|
seg_indptr=seg_indptr,
|
||||||
|
masked_m=masked_m,
|
||||||
|
expected_m=expected_m,
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
* self.routed_scaling_factor
|
* self.routed_scaling_factor
|
||||||
)
|
)
|
||||||
if self.tp_size > 1:
|
if self.ep_size > 1:
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
final_hidden_states, forward_mode
|
final_hidden_states,
|
||||||
|
topk_idx,
|
||||||
|
topk_weights,
|
||||||
|
forward_mode,
|
||||||
)
|
)
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|||||||
@@ -161,6 +161,7 @@ class ServerArgs:
|
|||||||
enable_dp_attention: bool = False
|
enable_dp_attention: bool = False
|
||||||
enable_ep_moe: bool = False
|
enable_ep_moe: bool = False
|
||||||
enable_deepep_moe: bool = False
|
enable_deepep_moe: bool = False
|
||||||
|
deepep_mode: Optional[str] = "auto"
|
||||||
enable_torch_compile: bool = False
|
enable_torch_compile: bool = False
|
||||||
torch_compile_max_bs: int = 32
|
torch_compile_max_bs: int = 32
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
@@ -285,6 +286,13 @@ class ServerArgs:
|
|||||||
if self.grammar_backend is None:
|
if self.grammar_backend is None:
|
||||||
self.grammar_backend = "xgrammar"
|
self.grammar_backend = "xgrammar"
|
||||||
|
|
||||||
|
# Expert parallelism
|
||||||
|
if self.enable_ep_moe:
|
||||||
|
self.ep_size = self.tp_size
|
||||||
|
logger.info(
|
||||||
|
f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
|
||||||
|
)
|
||||||
|
|
||||||
# Data parallelism attention
|
# Data parallelism attention
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
|
||||||
@@ -300,6 +308,10 @@ class ServerArgs:
|
|||||||
self.enable_sp_layernorm = False
|
self.enable_sp_layernorm = False
|
||||||
# DeepEP MoE
|
# DeepEP MoE
|
||||||
if self.enable_deepep_moe:
|
if self.enable_deepep_moe:
|
||||||
|
if self.deepep_mode == "auto":
|
||||||
|
assert (
|
||||||
|
not self.enable_dp_attention
|
||||||
|
), "DeepEP MoE `auto` mode is not supported with DP Attention."
|
||||||
self.ep_size = self.tp_size
|
self.ep_size = self.tp_size
|
||||||
self.enable_sp_layernorm = (
|
self.enable_sp_layernorm = (
|
||||||
self.dp_size < self.tp_size if self.enable_dp_attention else True
|
self.dp_size < self.tp_size if self.enable_dp_attention else True
|
||||||
@@ -1082,6 +1094,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enabling DeepEP MoE implementation for EP MoE.",
|
help="Enabling DeepEP MoE implementation for EP MoE.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--deepep-mode",
|
||||||
|
type=str,
|
||||||
|
choices=["normal", "low_latency", "auto"],
|
||||||
|
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
||||||
|
)
|
||||||
|
|
||||||
# Server warmups
|
# Server warmups
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user