[Feature] Integrate DeepEP into SGLang (#4232)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
Co-authored-by: Xuting Zhou <xutingz@nvidia.com>
@@ -687,10 +687,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
):
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_merged_column_weight(
loaded_weight=loaded_weight,
shard_id=0,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
param.load_merged_column_weight(
loaded_weight=loaded_weight,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -719,6 +728,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size,
use_presharded_weights=self.use_presharded_weights,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
)

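The hunks above thread tp_rank/tp_size explicitly into load_merged_column_weight instead of having the parameter class look the rank up globally (the parameter.py hunk later in this diff removes that lookup). A minimal sketch of what the column-parallel sharding amounts to; the helper name and shapes here are illustrative, not part of this commit:

import torch

def shard_column_weight(full_weight: torch.Tensor, output_dim: int, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Same narrow() the vLLM-style parameter performs, but driven by the rank/size
    # handed in by the linear layer instead of a global TP-rank lookup.
    shard_size = full_weight.shape[output_dim] // tp_size
    return full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
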
@@ -5,6 +5,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import is_cuda

@@ -16,6 +17,117 @@ if _is_cuda:
logger = logging.getLogger(__name__)


@triton.jit
def compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def deepep_compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
num_invalid = tl.load(num_minus_one)
tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)


def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

# Find offset
expert_ids = torch.arange(
num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
)
torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
num_minus_one = seg_indptr[0]
seg_indptr = seg_indptr - num_minus_one

BLOCK_SIZE = 512
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
deepep_compute_src2dst_triton_kernel[grid](
reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
)

reorder_topk_ids = reorder_topk_ids[num_minus_one:]
return reorder_topk_ids, src2dst, seg_indptr

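deepep_run_moe_deep_preprocess above sorts the flattened top-k expert ids, builds per-expert segment pointers, and drops slots routed to -1. A tiny CPU walk-through of that bookkeeping (illustrative only, not part of this commit):

import torch

topk_ids = torch.tensor([[1, 0], [0, -1]])  # -1 marks a pruned routing slot
num_experts = 2
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)  # [-1, 0, 0, 1]
expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)  # [1, 3, 4]
num_minus_one = seg_indptr[0]            # one -1 entry was sorted to the front
seg_indptr = seg_indptr - num_minus_one  # [0, 2, 3]: expert 0 owns rows [0, 2), expert 1 owns [2, 3)
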
@triton.jit
def deepep_permute_triton_kernel(
input_ptr,
gateup_input_ptr,
src2dst_ptr,
topk_ids_ptr,
a1_scales_ptr,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
OutDtype = gateup_input_ptr.dtype.element_ty

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk

src_ptr = input_ptr + src_idx * hidden_size

for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)

for idx in range(topk):
dst_idx = tl.load(src2dst_ptr + idx)
if dst_idx >= 0:
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
out_data = (in_data).to(OutDtype)
tl.store(dst_ptr + offset, out_data, mask=mask)


@triton.jit
def deepep_post_reorder_triton_kernel(
down_output_ptr,
output_ptr,
src2dst_ptr,
topk_ids_ptr,
topk_weights_ptr,
topk,
hidden_size,
BLOCK_SIZE: tl.constexpr,
):
InDtype = down_output_ptr.dtype.element_ty

src_idx = tl.program_id(0)
src2dst_ptr = src2dst_ptr + src_idx * topk
topk_ids_ptr = topk_ids_ptr + src_idx * topk
topk_weights_ptr = topk_weights_ptr + src_idx * topk

store_ptr = output_ptr + src_idx * hidden_size
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
offset = start_offset + tl.arange(0, BLOCK_SIZE)
mask = offset < hidden_size
sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
for idx in range(topk):
dst_idx = tl.load(src2dst_ptr + idx)
if dst_idx >= 0:
weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
load_ptr = down_output_ptr + dst_idx * hidden_size
in_data = tl.load(load_ptr + offset, mask=mask)
sum_vec += in_data * weigh_scale
tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
expert = tl.program_id(0)
@@ -33,17 +145,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
tl.store(seg_indptr + expert + 1, target_location + 1)


@triton.jit
def compute_src2dst_triton_kernel(
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(axis=0)
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = dst_id < num_toks
src_id = tl.load(reorder_ids + dst_id, mask=mask)
tl.store(src2dst + src_id, dst_id, mask=mask)


def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)

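For reference, the gather-and-weight step that deepep_post_reorder_triton_kernel performs can be written in plain PyTorch as below. This is an illustrative equivalent for reading the kernel, not code from this commit; negative src2dst entries mark slots the router dropped.

import torch

def post_reorder_reference(down_output, src2dst, topk_weights):
    num_tokens, topk = src2dst.shape
    output = torch.zeros(num_tokens, down_output.shape[1], dtype=down_output.dtype, device=down_output.device)
    for src in range(num_tokens):
        for k in range(topk):
            dst = int(src2dst[src, k])
            if dst >= 0:  # skip slots that were routed to -1
                output[src] += topk_weights[src, k].to(down_output.dtype) * down_output[dst]
    return output
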
@@ -2,6 +2,13 @@ import logging
from typing import Callable, List, Optional, Tuple

import torch

# TODO: use deep_gemm masked kernel after low latency dispatch
# import deep_gemm
# from deep_gemm import (
# get_col_major_tma_aligned_tensor,
# m_grouped_gemm_fp8_fp8_bf16_nt_masked,
# )
from torch.nn import Module

from sglang.srt.custom_op import CustomOp
@@ -25,6 +32,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs

_is_cuda = is_cuda()
@@ -39,6 +47,8 @@ logger = logging.getLogger(__name__)

_is_hip = is_hip()

_buffer = None


class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
@@ -773,3 +783,267 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
raise NotImplementedError


class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
"""

_has_printed = False

def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
prefix,
correction_bias,
custom_routing_function,
activation,
)

def forward(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
forward_mode: ForwardMode,
):
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
if True: # not forward_mode.is_decode():
return self.forward_normal(hidden_states, tokens_per_expert)
else:
return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)

def forward_normal(
self,
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
):
assert self.quant_method is not None
assert self.activation == "silu"
if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
seg_indptr_cur_rank = torch.cat(
[
torch.zeros(
1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
),
torch.cumsum(tokens_per_expert, dim=0),
]
)
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
weight_indices_cur_rank = torch.arange(
0,
self.num_experts_per_partition,
device=hidden_states.device,
dtype=torch.int64,
)

# GroupGemm-0
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
gateup_output = self.grouped_gemm_runner(
a=hidden_states,
b=self.w13_weight,
c=gateup_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w13_input_scale,
scale_b=(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
block_shape=self.block_shape,
)

# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)

if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")

# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if down_input.shape[0] > 0:
down_output = self.grouped_gemm_runner(
a=down_input,
b=self.w2_weight,
c=down_output,
batch_size=self.num_experts_per_partition,
weight_column_major=True,
seg_indptr=seg_indptr_cur_rank,
weight_indices=weight_indices_cur_rank,
use_fp8_w8a8=self.use_fp8_w8a8,
scale_a=self.w2_input_scale,
scale_b=(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
block_shape=self.block_shape,
)
return down_output

def forward_deepgemm_masked(
self,
hidden_states: torch.Tensor,
reorder_topk_ids: torch.Tensor,
seg_indptr: torch.Tensor,
):
assert self.quant_method is not None
assert self.activation == "silu"

if self.activation_scheme == "dynamic" and not self.use_block_quant:
max_value = (
torch.max(hidden_states)
.repeat(self.num_experts_per_partition)
.to(torch.float32)
)
self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max

# GroupGemm-0
gateup_output = torch.empty(
hidden_states.shape[0],
self.w13_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if hidden_states.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
hidden_states = (
hidden_states[0],
get_col_major_tma_aligned_tensor(hidden_states[1]),
)
"""
gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
hidden_states, self.w13_weight, out, masked_m, expected_m
)
"""

# Act
down_input = torch.empty(
gateup_output.shape[0],
gateup_output.shape[1] // 2,
device=gateup_output.device,
dtype=(
self.fp8_dtype
if (self.use_fp8_w8a8 and not self.use_block_quant)
else hidden_states.dtype
),
)
if self.w2_input_scale is None and not self.use_block_quant:
self.w2_input_scale = torch.ones(
self.num_experts_per_partition,
dtype=torch.float32,
device=hidden_states.device,
)

if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
0,
self.num_experts_per_partition - 1,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")

# GroupGemm-1
down_output = torch.empty(
down_input.shape[0],
self.w2_weight.shape[1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
if down_input.shape[0] > 0:
# Transpose earlier so that the testing will not trigger transposing kernels
down_input = (
down_input[0],
get_col_major_tma_aligned_tensor(down_input[1]),
)
"""
down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
down_input, self.w2_weight, out, masked_m, expected_m
)
"""

return down_output

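forward_normal above rebuilds its grouped-GEMM metadata directly from the tokens_per_expert tensor returned by the dispatcher. A minimal illustration of that step (assumed values, not part of this commit):

import torch

tokens_per_expert = torch.tensor([3, 0, 2])  # tokens received for each local expert on this rank
seg_indptr_cur_rank = torch.cat(
    [torch.zeros(1, dtype=tokens_per_expert.dtype), torch.cumsum(tokens_per_expert, dim=0)]
)  # tensor([0, 3, 3, 5]): row ranges per expert for the grouped GEMMs
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)  # tensor([0, 0, 0, 2, 2]): expert id per row
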
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (new file)
@@ -0,0 +1,533 @@
try:
from deep_ep import Buffer

use_deepep = True
except ImportError:
use_deepep = False

import os
from typing import Optional, Tuple

import torch
import torch.distributed as dist

from sglang.srt.layers.moe.ep_moe.kernels import (
compute_src2dst_triton_kernel,
deepep_permute_triton_kernel,
deepep_post_reorder_triton_kernel,
deepep_run_moe_deep_preprocess,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode

_buffer_normal = None
_buffer_low_latency = None


def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
"""
Copied from the DeepEP example usage for model inference prefilling.
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
"""

global _buffer_normal

num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (
Buffer.get_dispatch_config(group.size()),
Buffer.get_combine_config(group.size()),
):
num_nvl_bytes = max(
config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
)
num_rdma_bytes = max(
config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
)

if (
_buffer_normal is None
or _buffer_normal.group != group
or _buffer_normal.num_nvl_bytes < num_nvl_bytes
or _buffer_normal.num_rdma_bytes < num_rdma_bytes
):
_buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer_normal


def get_buffer_low_latency(
group: dist.ProcessGroup,
num_max_dispatch_tokens_per_rank: int,
hidden: int,
num_experts: int,
):
"""
Copied from the DeepEP example usage for model inference decoding.
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
"""

global _buffer_low_latency
num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
)

if (
_buffer_low_latency is None
or _buffer_low_latency.group != group
or not _buffer_low_latency.low_latency_mode
or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
):
assert num_experts % group.size() == 0
_buffer_low_latency = Buffer(
group,
0,
num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_experts // group.size(),
)
return _buffer_low_latency


def permute(
tokens,
routing_map,
num_out_tokens: Optional[int] = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Copied from the Megatron-Core MoE token permutation utilities.
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
"""

num_tokens, _ = tokens.shape
num_experts = routing_map.shape[1]
if drop_and_pad and not (num_out_tokens is None):
capacity = num_out_tokens // num_experts
assert not routing_map.requires_grad
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
:, :capacity
].contiguous()
sorted_indices = sorted_indices.view(-1)
else:
routing_map = routing_map.bool().T.contiguous()
token_indices = (
torch.arange(num_tokens, device=routing_map.device)
.unsqueeze(0)
.expand(num_experts, -1)
)
sorted_indices = token_indices.masked_select(routing_map)
permuted_input = tokens.index_select(0, sorted_indices)

return permuted_input, sorted_indices


def unpermute(
permuted_tokens: torch.Tensor,
sorted_indices: torch.Tensor,
restore_shape: torch.Size,
probs: torch.Tensor = None,
routing_map: torch.Tensor = None,
fused: bool = False,
drop_and_pad: bool = False,
):
"""
Copied from the Megatron-Core MoE token unpermutation utilities.
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
"""

_, hidden = restore_shape

if probs is not None:
assert routing_map is not None, "Mask must be provided to permute the probs."
if drop_and_pad:
num_experts = routing_map.size(1)
num_permuted_tokens = sorted_indices.size(0)
capacity = num_permuted_tokens // num_experts
num_unpermuted_tokens = probs.size(0)

probs_T_1D = probs.T.contiguous().view(-1)

indices_dim0 = torch.arange(
num_experts, device=routing_map.device
).unsqueeze(-1)
indices_dim1 = sorted_indices.view(num_experts, capacity)
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)

permuted_probs = probs_T_1D.index_select(0, indices_1D)
else:
permuted_probs = probs.T.contiguous().masked_select(
routing_map.T.contiguous()
)
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)

output_tokens = torch.zeros(
restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
)
output_tokens.scatter_add_(
0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens
)

return output_tokens


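A tiny round-trip of the two helpers above (illustrative only; weights of 1.0 are assumed so the effect of the scatter-add is easy to see):

import torch

tokens = torch.arange(6, dtype=torch.float32).view(3, 2)                # 3 tokens, hidden=2
routing_map = torch.tensor([[1, 0], [1, 1], [0, 1]], dtype=torch.bool)  # token -> expert mask
probs = routing_map.float()                                             # weight 1.0 per selected expert

permuted, sorted_indices = permute(tokens, routing_map)                 # 4 rows, grouped by expert
restored = unpermute(
    permuted, sorted_indices, restore_shape=tokens.shape, probs=probs, routing_map=routing_map
)
# A token routed to k experts comes back scaled by k (token 1 here), since each copy is
# scattered back with weight 1.0; real top-k weights turn this into a weighted combine.
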
class DeepEPDispatcher:
"""
Copied from the Megatron-Core token_dispatcher MoEFlexTokenDispatcher.
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
"""

def __init__(
self,
group: torch.distributed.ProcessGroup,
router_topk: int,
permute_fusion: bool = False,
capacity_factor: float = None,
num_experts: int = None,
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
):
self.group = group
self.router_topk = router_topk
self.capacity_factor = capacity_factor
self.permute_fusion = permute_fusion
self.num_experts = num_experts
self.num_local_experts = num_local_experts
self.hidden_size = hidden_size
self.recv_expert_count = None
self.params_dtype = params_dtype
self.params_bytes = 2
# Metadata
self.token_indices = None
self.token_probs = None
# Handle used for combine operation
self.handle = None

# `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
# https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
self.num_max_dispatch_tokens_per_rank = 128

if not use_deepep:
raise ImportError(
"DeepEP is not installed. Please install DeepEP package from "
"https://github.com/deepseek-ai/deepep."
)
self.buffer_normal = get_buffer_normal(
self.group, self.hidden_size * self.params_bytes
)
self.buffer_low_latency = None
# Todo: enable low latency dispatch
"""
self.buffer_low_latency = get_buffer_low_latency(
self.group,
self.num_max_dispatch_tokens_per_rank,
self.hidden_size * self.params_bytes,
self.num_experts,
)
"""

def deepep_permute(
self,
topk_ids,
hidden_states,
num_experts,
top_k,
use_fp8_w8a8,
use_block_quant,
fp8_dtype,
):
reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
topk_ids, num_experts
)
num_total_tokens = reorder_topk_ids.numel()
gateup_input = torch.empty(
(int(num_total_tokens), hidden_states.shape[1]),
device=hidden_states.device,
dtype=(
fp8_dtype
if (use_fp8_w8a8 and not use_block_quant)
else hidden_states.dtype
),
)
# PreReorder
deepep_permute_triton_kernel[(hidden_states.shape[0],)](
hidden_states,
gateup_input,
src2dst,
topk_ids,
None,
top_k,
hidden_states.shape[1],
BLOCK_SIZE=512,
)
self.src2dst = src2dst
return reorder_topk_ids, seg_indptr, gateup_input

def dispatch(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
forward_mode: ForwardMode,
previous_event=None,
num_max_dispatch_tokens_per_rank: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.hidden_shape = hidden_states.shape
topk_idx = topk_idx.to(torch.int64)
# Todo: enable low latency dispatch
if True: # not forward_mode.is_decode():
(
hidden_states,
topk_idx,
topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
) = self.dispatch_normal(
hidden_states, topk_idx, topk_weights, num_experts, previous_event
)
self.tokens_per_expert = torch.tensor(
num_recv_tokens_per_expert_list,
device=hidden_states.device,
dtype=torch.int64,
)
else:
hidden_states, recv_expert_count, handle, event, hook = (
self.dispatch_low_latency(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
)
)
self.recv_expert_count = recv_expert_count
tokens_per_expert = self.get_number_of_tokens_per_expert()
self.handle = handle
self.topk_idx = topk_idx
self.topk_weights = topk_weights
if hidden_states.shape[0] > 0:
hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
return hidden_states, topk_idx, topk_weights, tokens_per_expert

def dispatch_normal(
self,
x: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
num_experts: int,
previous_event=None,
):
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
previous_event,
) = self.buffer_normal.get_dispatch_layout(
topk_idx,
num_experts,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)

(
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
) = self.buffer_normal.dispatch(
x,
topk_idx=topk_idx,
topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)

return (
recv_x,
recv_topk_idx,
recv_topk_weights,
num_recv_tokens_per_expert_list,
handle,
event,
)

def dispatch_low_latency(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
num_max_dispatch_tokens_per_rank: int,
num_experts: int,
):
"""
# For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
# Please make sure to change the DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
# For more details, see: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
+
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index f60e933..cddaabf 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
- constexpr int kNumWarpsPerGroup = 10;
- constexpr int kNumWarpGroups = 3;
+ constexpr int kNumWarpsPerGroup = 8;
+ constexpr int kNumWarpGroups = 4;
EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
const auto num_sms = cell_div(num_experts, kNumWarpGroups);
EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
- EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+ // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+
// Workspace checks
auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
@@ -505,8 +505,8 @@ void combine(void* combined_x,
int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
void* workspace, cudaStream_t stream, int phases) {
- constexpr int kNumWarpsPerGroup = 10;
- constexpr int kNumWarpGroups = 3;
+ constexpr int kNumWarpsPerGroup = 8;
+ constexpr int kNumWarpGroups = 4;
constexpr int kNumMaxTopk = 9;
+
const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
"""

recv_hidden_states, recv_expert_count, handle, event, hook = (
self.buffer_low_latency.low_latency_dispatch(
hidden_states,
topk_idx,
num_max_dispatch_tokens_per_rank,
num_experts,
async_finish=False,
return_recv_hook=False, # True for double-batch overlapping; requires calling hook()
)
)
# hook()
return recv_hidden_states, recv_expert_count, handle, event, hook

def combine(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Todo: enable low latency combine
if True: # not forward_mode.is_decode():
if hidden_states.shape[0] > 0:
hidden_states = self.get_restored_hidden_states_by_experts(
hidden_states
)
hidden_states, event = self.combine_normal(hidden_states, self.handle)
else:
hidden_states, event, hook = self.combine_low_latency(
hidden_states, self.topk_idx, self.topk_weights, self.handle
)
self.handle = None
return hidden_states.view(self.hidden_shape)

def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
combined_x, _, event = self.buffer_normal.combine(
x,
handle,
async_finish=False,
previous_event=previous_event,
allocate_on_comm_stream=False,
)
return combined_x, event

def combine_low_latency(
self,
hidden_states: torch.Tensor,
topk_idx: torch.Tensor,
topk_weights: torch.Tensor,
handle: Tuple,
):
combined_hidden_states, event_overlap, hook = (
self.buffer_low_latency.low_latency_combine(
hidden_states,
topk_idx,
topk_weights,
handle,
async_finish=False,
return_recv_hook=False, # True for double-batch overlapping; requires calling hook()
)
)
# hook()
return combined_hidden_states, event_overlap, hook

def _indices_to_multihot(self, indices, probs):
batch_size = indices.shape[0]
multihot_routing_map = torch.zeros(
(batch_size, self.num_local_experts),
dtype=torch.long,
device=indices.device,
)

multihot_probs = torch.zeros(
(batch_size, self.num_local_experts),
dtype=torch.float,
device=indices.device,
)

mask = indices != -1
valid_indices = indices[mask]
row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave(
mask.sum(dim=1)
)
multihot_routing_map[row_indices, valid_indices] = 1
multihot_probs[row_indices, valid_indices] = probs[mask]
return multihot_routing_map.bool(), multihot_probs

def get_dispached_metadata(self) -> torch.Tensor:
return self.topk_idx, self.topk_weights

def get_number_of_tokens_per_expert(self) -> torch.Tensor:
"""
Get the number of tokens per expert.
"""
return self.tokens_per_expert

def get_permuted_hidden_states_by_experts(
self, hidden_states: torch.Tensor
) -> torch.Tensor:
self.dispatched_routing_map, self.topk_weights = self._indices_to_multihot(
self.topk_idx, self.topk_weights
)
self.hidden_shape_before_permute = hidden_states.shape
hidden_states, self.reversed_mapping_for_combine = permute(
hidden_states,
self.dispatched_routing_map,
num_out_tokens=self.tokens_per_expert.sum(),
fused=self.permute_fusion,
)
return hidden_states

def get_restored_hidden_states_by_experts(
self, hidden_states: torch.Tensor
) -> torch.Tensor:
input_dtype = hidden_states.dtype
assert (
self.topk_weights.dtype == torch.float32
), "DeepEP only supports float32 probs"
hidden_states = unpermute(
hidden_states,
self.reversed_mapping_for_combine,
restore_shape=self.hidden_shape_before_permute,
routing_map=self.dispatched_routing_map,
probs=self.topk_weights,
fused=self.permute_fusion,
)
return hidden_states.to(input_dtype)
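Putting the pieces together, the intended calling pattern (mirrored by DeepseekV2MoE.forward_deepep later in this diff) is roughly the sketch below. All argument values are placeholders, and a real run needs a multi-GPU process group with DeepEP installed:

# Hedged usage sketch of DeepEPDispatcher; variable values are placeholders, not from this commit.
dispatcher = DeepEPDispatcher(
    group=group,                                   # the TP/EP process group
    router_topk=top_k,
    permute_fusion=True,
    num_experts=num_experts,
    num_local_experts=num_experts // group.size(),
    hidden_size=hidden_size,
    params_dtype=params_dtype,
)
hidden_states, topk_idx, topk_weights, tokens_per_expert = dispatcher.dispatch(
    hidden_states, topk_idx, topk_weights, num_experts, forward_mode
)
expert_output = deepep_moe(hidden_states, tokens_per_expert, forward_mode)  # DeepEPMoE.forward
combined = dispatcher.combine(expert_output, forward_mode)
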
@@ -105,6 +105,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):

shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
tp_rank = kwargs.get("tp_rank")
use_presharded_weights = kwargs.get("use_presharded_weights")
if (
isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
@@ -116,7 +117,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):

param_data = self.data

tp_rank = get_tensor_model_parallel_rank()
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(

@@ -67,6 +67,7 @@ global_server_args_dict = {
"enable_nan_detection": ServerArgs.enable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,

@@ -145,6 +145,7 @@ class ModelRunner:
"enable_nan_detection": server_args.enable_nan_detection,
"enable_dp_attention": server_args.enable_dp_attention,
"enable_ep_moe": server_args.enable_ep_moe,
"enable_deepep_moe": server_args.enable_deepep_moe,
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
@@ -277,6 +278,12 @@ class ModelRunner:
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True

if server_args.enable_deepep_moe:
logger.info("DeepEP is turned on.")
assert (
server_args.enable_dp_attention == True
), "Currently DeepEP is bound to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"

def init_torch_distributed(self):
logger.info("Init torch distributed begin.")

python/sglang/srt/models/deepseek_v2.py (executable file → normal file)
@@ -26,6 +26,7 @@ from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
@@ -47,8 +48,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
block_quant_to_tensor_quant,
@@ -65,7 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip

@@ -87,6 +90,8 @@ class DeepseekV2MLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -95,6 +100,8 @@ class DeepseekV2MLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.down_proj = RowParallelLinear(
intermediate_size,
@@ -103,6 +110,8 @@ class DeepseekV2MLP(nn.Module):
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
if hidden_act != "silu":
raise ValueError(
@@ -167,7 +176,11 @@ class DeepseekV2MoE(nn.Module):

self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
@@ -184,16 +197,59 @@ class DeepseekV2MoE(nn.Module):

if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
# disable tp for shared experts when deepep moe is enabled
if not global_server_args_dict["enable_deepep_moe"]:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
)
else:
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
tp_rank=0,
tp_size=1,
)

if global_server_args_dict["enable_deepep_moe"]:
self.num_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob
self.topk_group = config.topk_group
self.num_expert_group = config.n_group
self.correction_bias = (
self.gate.e_score_correction_bias.data
if self.gate.e_score_correction_bias is not None
else None
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.deepep_dispatcher = DeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=config.n_routed_experts,
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
)

def forward(
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
) -> torch.Tensor:
if not global_server_args_dict["enable_deepep_moe"]:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_mode)

def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.n_shared_experts is not None:
@@ -208,6 +264,59 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)

def forward_deepep(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if forward_mode is not None and not forward_mode.is_idle():
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)
if self.tp_size > 1:
recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
self.num_experts,
forward_mode,
)
)
final_hidden_states = (
self.experts(
hidden_states=recv_hidden_states,
tokens_per_expert=tokens_per_expert,
forward_mode=forward_mode,
)
* self.routed_scaling_factor
)
if self.tp_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states, forward_mode
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output

return final_hidden_states.view(num_tokens, hidden_dim)

@@ -959,15 +1068,25 @@ class DeepseekV2DecoderLayer(nn.Module):
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
if global_server_args_dict["enable_deepep_moe"] and isinstance(
self.mlp, DeepseekV2MoE
):
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
return hidden_states, residual
else:
if get_attention_tp_rank() == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = self.post_attention_layernorm(
@@ -1099,7 +1218,11 @@ class DeepseekV2ForCausalLM(nn.Module):

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = (
DeepEPMoE
if global_server_args_dict["enable_deepep_moe"]
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
)
expert_params_mapping = MoEImpl.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",

@@ -157,6 +157,7 @@ class ServerArgs:
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: Optional[int] = None
@@ -281,6 +282,12 @@ class ServerArgs:
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
)
# DeepEP MoE
if self.enable_deepep_moe:
self.ep_size = self.dp_size
logger.info(
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size [{self.dp_size}]."
)

# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
@@ -1018,6 +1025,11 @@ class ServerArgs:
default=ServerArgs.hicache_ratio,
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="Enable the DeepEP MoE implementation for EP MoE.",
)

# Server warmups
parser.add_argument(

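With these flags in place, DeepEP MoE is opt-in: DP attention must also be enabled (the ModelRunner assert above), and ep_size is forced to dp_size. A hypothetical launch, assuming a DeepSeek-style checkpoint and that the usual sglang flags are used alongside the new one, would pass both switches, e.g. python -m sglang.launch_server --model-path <deepseek-model> --tp-size 8 --dp-size 8 --enable-dp-attention --enable-deepep-moe (model path and parallel sizes are examples only).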