[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)
This commit is contained in:
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
|
||||
_global_dp_buffer_len: int
|
||||
_local_dp_buffer_len: int
|
||||
_global_num_tokens: Optional[List[int]]
|
||||
_is_extend_in_batch: bool
|
||||
|
||||
@classmethod
|
||||
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
|
||||
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
|
||||
def get_dp_device(cls) -> torch.device:
|
||||
return cls._device
|
||||
|
||||
@classmethod
def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
    """Store the extend-in-batch flag on the class.

    The value is kept in the class attribute ``_is_extend_in_batch`` and is
    returned unchanged by ``get_is_extend_in_batch``.
    """
    cls._is_extend_in_batch = is_extend_in_batch
|
||||
|
||||
@classmethod
def get_is_extend_in_batch(cls) -> bool:
    """Return the flag last stored in the class attribute ``_is_extend_in_batch``."""
    flag = cls._is_extend_in_batch
    return flag
|
||||
|
||||
|
||||
def set_dp_buffer_len(
|
||||
global_dp_buffer_len: int,
|
||||
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
|
||||
return _DpGatheredBufferWrapper.get_dp_device()
|
||||
|
||||
|
||||
def set_is_extend_in_batch(is_extend_in_batch: bool):
    """Forward the extend-in-batch flag to ``_DpGatheredBufferWrapper``.

    Module-level convenience wrapper around the classmethod of the same name.
    """
    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
|
||||
|
||||
|
||||
def get_is_extend_in_batch() -> bool:
    """Return the extend-in-batch flag held by ``_DpGatheredBufferWrapper``.

    Module-level convenience wrapper around the classmethod of the same name.
    """
    flag = _DpGatheredBufferWrapper.get_is_extend_in_batch()
    return flag
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
|
||||
@@ -566,7 +566,9 @@ def ep_scatter(
|
||||
scale_hidden_size = ceil_div(scale_hidden_size, 4)
|
||||
|
||||
assert m_indices.shape[0] % BLOCK_E == 0
|
||||
assert recv_x_scale.dtype == output_tensor_scale.dtype
|
||||
assert (
|
||||
recv_x_scale.dtype == output_tensor_scale.dtype
|
||||
), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
|
||||
assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
|
||||
|
||||
_fwd_kernel_ep_scatter_1[(grid,)](
|
||||
|
||||
@@ -20,18 +20,14 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
tma_align_input_scale,
|
||||
)
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
CUTEDSL_MOE_NVFP4_DISPATCH,
|
||||
ModelOptNvFp4FusedMoEMethod,
|
||||
)
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||
from sglang.srt.utils.offloader import get_offloader
|
||||
@@ -109,23 +105,6 @@ class DeepEPMoE(FusedMoE):
|
||||
|
||||
self.deepep_mode = get_deepep_mode()
|
||||
|
||||
# TODO: move to the beginning of the file
|
||||
from sglang.srt.distributed.parallel_state import get_tp_group
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=self.num_local_experts,
|
||||
hidden_size=hidden_size,
|
||||
params_dtype=params_dtype,
|
||||
deepep_mode=self.deepep_mode,
|
||||
async_finish=True, # TODO
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||
# NPU supports low_latency deepep without deepgemm
|
||||
assert (
|
||||
@@ -165,19 +144,16 @@ class DeepEPMoE(FusedMoE):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
topk_output: TopKOutput,
|
||||
forward_shared_experts=None,
|
||||
alt_stream=None,
|
||||
disable_sbo=False,
|
||||
):
|
||||
|
||||
# We have to call SBO inside MoE to be compatible with hooks used in offloading
|
||||
return single_batch_overlap.execute_sbo(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
topk_output=topk_output,
|
||||
# SBO args
|
||||
experts=self,
|
||||
forward_shared_experts=forward_shared_experts,
|
||||
@@ -188,25 +164,14 @@ class DeepEPMoE(FusedMoE):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
return self.deepep_dispatcher.dispatch(
|
||||
return self.dispatcher.dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
input_global_scale=(
|
||||
self.w13_input_scale_quant
|
||||
if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
|
||||
and self.quant_method.enable_flashinfer_cutedsl_moe
|
||||
and CUTEDSL_MOE_NVFP4_DISPATCH
|
||||
else None
|
||||
),
|
||||
topk_output=topk_output,
|
||||
)
|
||||
|
||||
def moe_impl(
|
||||
def run_moe_core(
|
||||
self,
|
||||
dispatch_output: DispatchOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
|
||||
@@ -240,16 +205,14 @@ class DeepEPMoE(FusedMoE):
|
||||
def combine(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
return self.deepep_dispatcher.combine(
|
||||
return self.dispatcher.combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
|
||||
@@ -257,9 +220,9 @@ class DeepEPMoE(FusedMoE):
|
||||
self,
|
||||
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
|
||||
):
|
||||
hidden_states, topk_idx, topk_weights = (
|
||||
hidden_states, topk_ids, topk_weights = (
|
||||
dispatch_output.hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_ids,
|
||||
dispatch_output.topk_weights,
|
||||
)
|
||||
if hidden_states.shape[0] == 0:
|
||||
@@ -267,15 +230,15 @@ class DeepEPMoE(FusedMoE):
|
||||
# in original deepep, idx == -1 meaning invalid and will not be processed.
|
||||
# aiter does not accept -1, we use a expert mask to make these idx invalid
|
||||
# (idx == num_local_experts) meaning not used in aiter fused_moe
|
||||
topk_idx_copy = topk_idx.to(torch.int32)
|
||||
topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
|
||||
topk_ids_copy = topk_ids.to(torch.int32)
|
||||
topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts
|
||||
|
||||
return fused_moe(
|
||||
hidden_states,
|
||||
self.w13_weight,
|
||||
self.w2_weight,
|
||||
topk_weights,
|
||||
topk_idx_copy,
|
||||
topk_ids_copy,
|
||||
w1_scale=self.w13_weight_scale_inv,
|
||||
w2_scale=self.w2_weight_scale_inv,
|
||||
quant_type=QuantType.per_128x128,
|
||||
@@ -291,18 +254,21 @@ class DeepEPMoE(FusedMoE):
|
||||
self,
|
||||
dispatch_output: DeepEPNormalOutput,
|
||||
):
|
||||
hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||
dispatch_output
|
||||
)
|
||||
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
|
||||
(
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
) = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
if num_recv_tokens_per_expert is None:
|
||||
return hidden_states_fp8.bfloat16()
|
||||
return hidden_states.bfloat16()
|
||||
all_tokens = sum(num_recv_tokens_per_expert)
|
||||
if all_tokens <= 0:
|
||||
return hidden_states_fp8.bfloat16()
|
||||
M, K = hidden_states_fp8.size()
|
||||
return hidden_states.bfloat16()
|
||||
M, K = hidden_states.size()
|
||||
N = self.w13_weight.size(1)
|
||||
scale_block_size = 128
|
||||
|
||||
@@ -323,35 +289,35 @@ class DeepEPMoE(FusedMoE):
|
||||
),
|
||||
)
|
||||
|
||||
hidden_states_fp8_shape = hidden_states_fp8.shape
|
||||
hidden_states_fp8_device = hidden_states_fp8.device
|
||||
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
||||
hidden_states_shape = hidden_states.shape
|
||||
hidden_states_device = hidden_states.device
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
|
||||
input_tensor = [
|
||||
torch.empty(
|
||||
(all_tokens, K),
|
||||
device=hidden_states_fp8.device,
|
||||
dtype=hidden_states_fp8.dtype,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
),
|
||||
(
|
||||
# TODO check whether need `zeros`
|
||||
torch.zeros(
|
||||
(ceil_div(K // 128, 4), all_tokens),
|
||||
device=hidden_states_fp8.device,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.int,
|
||||
).transpose(0, 1)
|
||||
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
else torch.empty(
|
||||
(all_tokens, K // 128),
|
||||
device=hidden_states_fp8.device,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
),
|
||||
]
|
||||
m_indices = torch.empty(
|
||||
all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
|
||||
all_tokens, device=hidden_states.device, dtype=torch.int32
|
||||
)
|
||||
output_index = torch.empty_like(topk_idx)
|
||||
output_index = torch.empty_like(topk_ids)
|
||||
|
||||
if get_offloader().forbid_copy_engine_usage:
|
||||
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
|
||||
@@ -367,9 +333,9 @@ class DeepEPMoE(FusedMoE):
|
||||
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
|
||||
|
||||
ep_scatter(
|
||||
hidden_states_fp8,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
num_recv_tokens_per_expert_gpu,
|
||||
expert_start_loc,
|
||||
input_tensor[0],
|
||||
@@ -378,11 +344,11 @@ class DeepEPMoE(FusedMoE):
|
||||
output_index,
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8)
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
gateup_output = torch.empty(
|
||||
(all_tokens, N),
|
||||
device=hidden_states_fp8_device,
|
||||
device=hidden_states_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
@@ -403,7 +369,7 @@ class DeepEPMoE(FusedMoE):
|
||||
del gateup_output
|
||||
down_output = torch.empty(
|
||||
(all_tokens, K),
|
||||
device=hidden_states_fp8_device,
|
||||
device=hidden_states_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
|
||||
@@ -425,11 +391,11 @@ class DeepEPMoE(FusedMoE):
|
||||
del down_input_fp8, down_input_scale
|
||||
|
||||
gather_out = torch.empty(
|
||||
hidden_states_fp8_shape,
|
||||
device=hidden_states_fp8_device,
|
||||
hidden_states_shape,
|
||||
device=hidden_states_device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
|
||||
ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
|
||||
|
||||
return gather_out
|
||||
|
||||
@@ -438,13 +404,13 @@ class DeepEPMoE(FusedMoE):
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
|
||||
):
|
||||
hidden_states, _, _, masked_m, _ = dispatch_output
|
||||
hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
|
||||
output = self.quant_method.apply_without_routing_weights(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
x=(hidden_states, hidden_states_scale),
|
||||
masked_m=masked_m,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
down_gemm_overlap_args=down_gemm_overlap_args,
|
||||
@@ -466,25 +432,28 @@ class DeepEPMoE(FusedMoE):
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
):
|
||||
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
|
||||
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
assert self.moe_runner_config.activation == "silu"
|
||||
assert (
|
||||
hidden_states_scale.dtype == torch.float32
|
||||
), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
|
||||
|
||||
# GroupGemm-0
|
||||
num_groups, m, k = hidden_states_fp8[0].size()
|
||||
num_groups, m, k = hidden_states.size()
|
||||
n = self.w13_weight.size(1)
|
||||
expected_m = min(expected_m, m)
|
||||
gateup_output = torch.empty(
|
||||
(num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
|
||||
(num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
|
||||
)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
|
||||
hidden_states_fp8,
|
||||
(hidden_states, hidden_states_scale),
|
||||
self.w13_weight_fp8,
|
||||
gateup_output,
|
||||
masked_m,
|
||||
expected_m,
|
||||
)
|
||||
dispose_tensor(hidden_states_fp8[0])
|
||||
dispose_tensor(hidden_states)
|
||||
|
||||
# Act
|
||||
down_input = torch.empty(
|
||||
@@ -557,11 +526,9 @@ class DeepEPMoE(FusedMoE):
|
||||
def _forward_normal(dispatch_output: DeepEPNormalOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPNormalOutput)
|
||||
hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
|
||||
dispatch_output
|
||||
)
|
||||
|
||||
group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
|
||||
hidden_states.device
|
||||
@@ -571,7 +538,7 @@ class DeepEPMoE(FusedMoE):
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||
# per_token_scale=[per_token_scale],
|
||||
# per_token_scale=[hidden_states_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
@@ -591,7 +558,7 @@ class DeepEPMoE(FusedMoE):
|
||||
)[0]
|
||||
else:
|
||||
if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
|
||||
hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states
|
||||
)
|
||||
# gmm1: gate_up_proj
|
||||
@@ -599,7 +566,7 @@ class DeepEPMoE(FusedMoE):
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight],
|
||||
scale=[self.w13_weight_scale.to(output_dtype)],
|
||||
per_token_scale=[per_token_scale],
|
||||
per_token_scale=[hidden_states_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
@@ -631,11 +598,14 @@ class DeepEPMoE(FusedMoE):
|
||||
def _forward_ll(dispatch_output: DeepEPLLOutput):
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(dispatch_output, DeepEPLLOutput)
|
||||
hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
per_token_scale = hidden_states[1]
|
||||
hidden_states = hidden_states[0]
|
||||
(
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
group_list,
|
||||
_,
|
||||
) = dispatch_output
|
||||
|
||||
group_list = group_list.to(torch.int64)
|
||||
|
||||
@@ -644,7 +614,7 @@ class DeepEPMoE(FusedMoE):
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[self.w13_weight.permute(0, 2, 1)],
|
||||
# per_token_scale=[per_token_scale],
|
||||
# per_token_scale=[hidden_states_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
@@ -678,7 +648,7 @@ class DeepEPMoE(FusedMoE):
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=self.w13_weight_scale.to(torch.float32),
|
||||
activation_scale=per_token_scale,
|
||||
activation_scale=hidden_states_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
|
||||
@@ -11,14 +11,19 @@ from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_moe_tensor_parallel_rank,
|
||||
get_moe_tensor_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.layers.moe import (
|
||||
MoeRunnerConfig,
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
get_moe_runner_backend,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
StandardDispatcher,
|
||||
StandardDispatchOutput,
|
||||
@@ -32,6 +37,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
|
||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
@@ -71,6 +77,27 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||
return tile_tokens_dim
|
||||
|
||||
|
||||
def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
    """Build the token dispatcher for the backend reported by ``get_moe_a2a_backend()``.

    Args:
        moe_runner_config: Runner configuration; supplies ``top_k``,
            ``num_experts``, ``num_local_experts``, ``hidden_size`` and
            ``params_dtype`` when a DeepEP dispatcher is built.

    Returns:
        A ``StandardDispatcher`` when the backend reports ``is_none()``, or a
        ``MaybeTboDeepEPDispatcher`` when it reports ``is_deepep()``.

    Raises:
        NotImplementedError: For any other backend.
    """
    a2a_backend = get_moe_a2a_backend()

    # No all-to-all backend: use the standard dispatcher.
    if a2a_backend.is_none():
        return StandardDispatcher(moe_runner_config)

    # Reject everything that is not DeepEP before building its arguments.
    if not a2a_backend.is_deepep():
        raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")

    return MaybeTboDeepEPDispatcher(
        group=get_tp_group().device_group,
        router_topk=moe_runner_config.top_k,
        permute_fusion=True,
        num_experts=moe_runner_config.num_experts,
        num_local_experts=moe_runner_config.num_local_experts,
        hidden_size=moe_runner_config.hidden_size,
        params_dtype=moe_runner_config.params_dtype,
        deepep_mode=get_deepep_mode(),
        async_finish=True,
        return_recv_hook=True,
    )
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
@@ -132,8 +159,6 @@ class FusedMoE(torch.nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.num_experts = num_experts
|
||||
self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.expert_map_cpu = None
|
||||
self.expert_map_gpu = None
|
||||
|
||||
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
|
||||
|
||||
@@ -149,19 +174,6 @@ class FusedMoE(torch.nn.Module):
|
||||
assert num_experts % self.moe_ep_size == 0
|
||||
self.num_local_experts = num_experts // self.moe_ep_size
|
||||
|
||||
if self.moe_ep_size > 1:
|
||||
# TODO(ch-wan): support shared experts fusion
|
||||
# Create a tensor of size num_experts filled with -1
|
||||
self.expert_map_cpu = torch.full(
|
||||
(self.num_experts,), -1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
# Create a expert map for the local experts
|
||||
self.expert_map_cpu[
|
||||
self.moe_ep_rank
|
||||
* self.num_local_experts : (self.moe_ep_rank + 1)
|
||||
* self.num_local_experts
|
||||
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
|
||||
|
||||
assert intermediate_size % self.moe_tp_size == 0
|
||||
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
|
||||
self.reduce_results = reduce_results
|
||||
@@ -219,7 +231,7 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.quant_method.create_moe_runner(self, self.moe_runner_config)
|
||||
self.dispatcher = StandardDispatcher()
|
||||
self.dispatcher = create_moe_dispatcher(self.moe_runner_config)
|
||||
|
||||
self.should_fuse_routed_scaling_factor_in_topk = isinstance(
|
||||
self.quant_method, ModelOptNvFp4FusedMoEMethod
|
||||
@@ -453,9 +465,12 @@ class FusedMoE(torch.nn.Module):
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
    """Translate a global expert id into this rank's local expert index.

    This rank owns the contiguous id range
    ``[moe_ep_rank * num_local_experts, (moe_ep_rank + 1) * num_local_experts)``.

    Args:
        expert_id: Global expert id.

    Returns:
        The zero-based offset of ``expert_id`` inside the owned range, or
        ``-1`` when the id falls outside it.
    """
    # Computed arithmetically so no per-layer lookup table is required.
    start_idx = self.moe_ep_rank * self.num_local_experts
    end_idx = (self.moe_ep_rank + 1) * self.num_local_experts
    if start_idx <= expert_id < end_idx:
        return expert_id - start_idx
    return -1
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
@@ -804,32 +819,18 @@ class FusedMoE(torch.nn.Module):
|
||||
origin_hidden_states_dim = hidden_states.shape[-1]
|
||||
assert self.quant_method is not None
|
||||
|
||||
if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
|
||||
if self.expert_map_cpu is not None and self.expert_map_gpu is None:
|
||||
# If we are in EP mode, we need to move the expert map to GPU.
|
||||
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
|
||||
|
||||
if self.expert_map_gpu is not None:
|
||||
if TopKOutputChecker.format_is_standard(topk_output):
|
||||
topk_output = topk_output._replace(
|
||||
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
|
||||
)
|
||||
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
|
||||
raise NotImplementedError()
|
||||
|
||||
dispatch_output = self.dispatcher.dispatch(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
|
||||
# TODO: consider using symmetric memory
|
||||
combine_input = self.quant_method.apply(
|
||||
layer=self,
|
||||
combine_input = self.run_moe_core(
|
||||
dispatch_output=dispatch_output,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
final_hidden_states = self.dispatcher.combine(combine_input)
|
||||
|
||||
# TODO: should we add some conditions here?
|
||||
final_hidden_states = final_hidden_states[
|
||||
..., :origin_hidden_states_dim
|
||||
].contiguous()
|
||||
@@ -839,6 +840,14 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def run_moe_core(self, dispatch_output: DispatchOutput, **kwargs) -> CombineInput:
    """Run the expert computation on already-dispatched tokens.

    Delegates to ``self.quant_method.apply`` with this layer, the dispatch
    output and any extra keyword arguments, and returns its result unchanged.
    """
    # TODO: consider using symmetric memory
    apply_fn = self.quant_method.apply
    return apply_fn(layer=self, dispatch_output=dispatch_output, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
cls,
|
||||
|
||||
@@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher.mooncake import (
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
StandardCombineInput,
|
||||
StandardDispatcher,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
@@ -38,6 +39,7 @@ __all__ = [
|
||||
"MooncakeCombineInput",
|
||||
"MooncakeDispatchOutput",
|
||||
"MooncakeEPDispatcher",
|
||||
"StandardDispatcher",
|
||||
"StandardDispatchOutput",
|
||||
"StandardCombineInput",
|
||||
"DeepEPConfig",
|
||||
|
||||
@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
|
||||
class DispatchOutput(Protocol):
|
||||
"""Protocol for dispatch outputs in different formats."""
|
||||
|
||||
# hidden_states is part of the protocol: every dispatch output exposes it.
|
||||
hidden_states: torch.Tensor
|
||||
|
||||
@property
|
||||
def format(self) -> DispatchOutputFormat: ...
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers import deep_gemm_wrapper
|
||||
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
BaseDispatcherConfig,
|
||||
@@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.utils import (
|
||||
DeepEPMode,
|
||||
get_deepep_config,
|
||||
@@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -61,9 +61,9 @@ logger = logging.getLogger(__name__)
|
||||
class DeepEPNormalOutput(NamedTuple):
|
||||
"""DeepEP normal dispatch output."""
|
||||
|
||||
hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
|
||||
# hidden_states_scale
|
||||
topk_idx: torch.Tensor
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: Optional[torch.Tensor]
|
||||
topk_ids: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
num_recv_tokens_per_expert: List[int]
|
||||
|
||||
@@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
|
||||
class DeepEPLLOutput(NamedTuple):
|
||||
"""DeepEP low latency dispatch output."""
|
||||
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
||||
topk_idx: torch.Tensor
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: Optional[torch.Tensor]
|
||||
topk_ids: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
masked_m: torch.Tensor
|
||||
expected_m: int
|
||||
@@ -314,9 +315,7 @@ class _DeepEPDispatcherImplBase:
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase:
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
@@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
|
||||
self.async_finish = async_finish
|
||||
self.src2dst = None
|
||||
self.quant_config = {}
|
||||
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
|
||||
topk_ids = topk_ids.to(torch.int64)
|
||||
if (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
and not get_moe_runner_backend().is_cutlass()
|
||||
@@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
|
||||
)
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return hidden_states, topk_idx, topk_weights, previous_event
|
||||
return hidden_states, topk_ids, topk_weights, previous_event
|
||||
|
||||
def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
|
||||
def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
event,
|
||||
) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
|
||||
) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
|
||||
event.current_stream_wait() if self.async_finish else ()
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states, hidden_states_scale = hidden_states
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
return DeepEPNormalOutput(
|
||||
hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
)
|
||||
|
||||
def _dispatch_core(
|
||||
self,
|
||||
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
previous_event,
|
||||
):
|
||||
@@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
is_token_in_rank,
|
||||
previous_event,
|
||||
) = buffer.get_dispatch_layout(
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
self.num_experts,
|
||||
previous_event=previous_event,
|
||||
async_finish=self.async_finish,
|
||||
@@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
|
||||
(
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_ids,
|
||||
recv_topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
self.handle,
|
||||
event,
|
||||
) = buffer.dispatch(
|
||||
x,
|
||||
topk_idx=topk_idx,
|
||||
topk_idx=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
num_tokens_per_rank=num_tokens_per_rank,
|
||||
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
|
||||
@@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
|
||||
return (
|
||||
recv_x,
|
||||
recv_topk_idx,
|
||||
recv_topk_ids,
|
||||
recv_topk_weights,
|
||||
num_recv_tokens_per_expert,
|
||||
event,
|
||||
@@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
)
|
||||
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||
output = hidden_states
|
||||
else:
|
||||
if hidden_states.shape[0] > 0:
|
||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||
output = torch.empty(
|
||||
(num_tokens, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||
hidden_states,
|
||||
output,
|
||||
self.src2dst,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.router_topk,
|
||||
hidden_states.shape[1],
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
output = torch.zeros(
|
||||
(0, hidden_states.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
raise NotImplementedError() # triton runner was supported but it's temporarily disabled
|
||||
|
||||
previous_event = Buffer.capture() if self.async_finish else None
|
||||
return output, previous_event
|
||||
|
||||
@@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
self.num_experts,
|
||||
)
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
    """Replace the dispatcher's ``quant_config`` mapping with ``quant_config``.

    The dict is stored by reference, not copied.
    """
    self.quant_config = quant_config
|
||||
|
||||
|
||||
class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def __init__(self, return_recv_hook: bool, **kwargs):
|
||||
@@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
"""
|
||||
self.return_recv_hook = return_recv_hook
|
||||
self.device_module = torch.get_device_module()
|
||||
self.quant_config = {}
|
||||
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
buffer = self._get_buffer()
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
|
||||
topk_ids = topk_ids.to(torch.int64)
|
||||
expected_m = (
|
||||
hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
|
||||
hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
|
||||
+ self.num_experts
|
||||
) // self.num_experts
|
||||
hidden_states, masked_m, event, hook = self._dispatch_core(
|
||||
hidden_states,
|
||||
input_global_scale,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
)
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def dispatch_b(
|
||||
self,
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
masked_m
|
||||
)
|
||||
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states, hidden_states_scale = hidden_states
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
deepep_output = DeepEPLLOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
hidden_states_scale,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def _dispatch_core(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
use_nvfp4 = use_fp8 = False
|
||||
input_global_scale = self.quant_config.get("input_global_scale", None)
|
||||
if input_global_scale is not None:
|
||||
use_nvfp4 = True
|
||||
elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
|
||||
@@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
|
||||
buffer.low_latency_dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.num_experts,
|
||||
use_fp8=use_fp8,
|
||||
@@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
hidden_states, event, hook = self._combine_core(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
@@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
def _combine_core(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
@@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
with ctx:
|
||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||
x=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_idx=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
handle=self.handle,
|
||||
async_finish=not self.return_recv_hook,
|
||||
@@ -688,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
self.num_experts,
|
||||
)
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
    """Replace the quantization config held by this low-latency dispatcher.

    The stored dict is the one ``_dispatch_core`` queries for the
    ``"input_global_scale"`` entry.

    Args:
        quant_config: Quantization settings; stored as-is (not copied) on
            ``self.quant_config``.
    """
    self.quant_config = quant_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Stage(Enum):
|
||||
@@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||
inner_state = self._get_impl(forward_batch).dispatch_a(
|
||||
inner_state = self._get_impl().dispatch_a(
|
||||
hidden_states=hidden_states,
|
||||
input_global_scale=input_global_scale,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
self._dispatch_intermediate_state = forward_batch, inner_state
|
||||
self._dispatch_intermediate_state = inner_state
|
||||
|
||||
def dispatch_b(self):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
||||
forward_batch, inner_state = self._dispatch_intermediate_state
|
||||
inner_state = self._dispatch_intermediate_state
|
||||
del self._dispatch_intermediate_state
|
||||
return self._get_impl(forward_batch).dispatch_b(*inner_state)
|
||||
return self._get_impl().dispatch_b(*inner_state)
|
||||
|
||||
def combine(self, *args, **kwargs) -> Tuple:
|
||||
self.combine_a(*args, **kwargs)
|
||||
@@ -773,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional["CombineOverlapArgs"] = None,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_batch).combine_a(
|
||||
inner_state = self._get_impl().combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
self._combine_intermediate_state = forward_batch, inner_state
|
||||
self._combine_intermediate_state = inner_state
|
||||
|
||||
def combine_b(self):
|
||||
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
||||
forward_batch, inner_state = self._combine_intermediate_state
|
||||
inner_state = self._combine_intermediate_state
|
||||
del self._combine_intermediate_state
|
||||
return self._get_impl(forward_batch).combine_b(*inner_state)
|
||||
return self._get_impl().combine_b(*inner_state)
|
||||
|
||||
def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||
forward_batch.is_extend_in_batch
|
||||
)
|
||||
def _get_impl(self) -> _DeepEPDispatcherImplBase:
|
||||
is_extend_in_batch = get_is_extend_in_batch()
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
|
||||
if resolved_deepep_mode == DeepEPMode.NORMAL:
|
||||
return self._normal_dispatcher
|
||||
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
|
||||
@@ -807,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
|
||||
def _update_stage(self, old_stage, new_stage):
|
||||
assert self._stage == old_stage
|
||||
self._stage = new_stage
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
    """Forward ``quant_config`` to every inner dispatcher the DeepEP mode enables.

    The low-latency dispatcher is updated first, then the normal one; a
    dispatcher whose mode is not enabled is left untouched.

    Args:
        quant_config: Quantization settings handed unchanged to each
            enabled inner dispatcher's ``set_quant_config``.
    """
    mode = self.deepep_mode
    if mode.enable_low_latency():
        self._low_latency_dispatcher.set_quant_config(quant_config)
    if mode.enable_normal():
        self._normal_dispatcher.set_quant_config(quant_config)
|
||||
|
||||
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from typing import NamedTuple, Optional, Tuple
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
CombineInput,
|
||||
@@ -12,6 +13,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.utils import DeepEPMode
|
||||
from sglang.srt.utils import get_int_env_var
|
||||
|
||||
@@ -27,16 +29,15 @@ from enum import Enum, auto
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MooncakeDispatchOutput(NamedTuple):
|
||||
"""Mooncake EP dispatch output."""
|
||||
|
||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
|
||||
topk_idx: torch.Tensor
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
topk_weights: torch.Tensor
|
||||
masked_m: torch.Tensor
|
||||
expected_m: int
|
||||
@@ -164,23 +165,23 @@ class _MooncakeEPDispatcherImpl:
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
|
||||
buffer = self._get_buffer()
|
||||
topk_idx = topk_idx.to(torch.int64)
|
||||
topk_ids = topk_ids.to(torch.int64)
|
||||
expected_m = (
|
||||
hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
|
||||
hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
|
||||
+ self.num_experts
|
||||
) // self.num_experts
|
||||
hidden_states, masked_m, event, hook = self._dispatch_core(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
use_fp8=True,
|
||||
)
|
||||
return (
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -191,7 +192,7 @@ class _MooncakeEPDispatcherImpl:
|
||||
def dispatch_b(
|
||||
self,
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -206,7 +207,7 @@ class _MooncakeEPDispatcherImpl:
|
||||
|
||||
return MooncakeDispatchOutput(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
masked_m,
|
||||
expected_m,
|
||||
@@ -215,14 +216,14 @@ class _MooncakeEPDispatcherImpl:
|
||||
def _dispatch_core(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8: bool = False,
|
||||
):
|
||||
buffer = self._get_buffer()
|
||||
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
||||
buffer.dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
self.active_ranks,
|
||||
self.num_max_dispatch_tokens_per_rank,
|
||||
self.num_experts,
|
||||
@@ -237,12 +238,12 @@ class _MooncakeEPDispatcherImpl:
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
hidden_states, event, hook = self._combine_core(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
)
|
||||
return hidden_states, event, hook
|
||||
@@ -254,13 +255,13 @@ class _MooncakeEPDispatcherImpl:
|
||||
def _combine_core(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
):
|
||||
buffer = self._get_buffer()
|
||||
combined_hidden_states, event, hook = buffer.combine(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
self.active_ranks,
|
||||
-1 if self.first_execution else self.timeout_us,
|
||||
@@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher):
|
||||
def dispatch_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
input_global_scale: Optional[torch.Tensor],
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
topk_output: TopKOutput,
|
||||
):
|
||||
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||
inner_state = self._get_impl(forward_batch).dispatch_a(
|
||||
inner_state = self._get_impl().dispatch_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
self._dispatch_intermediate_state = forward_batch, inner_state
|
||||
self._dispatch_intermediate_state = inner_state
|
||||
|
||||
def dispatch_b(self):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
||||
forward_batch, inner_state = self._dispatch_intermediate_state
|
||||
inner_state = self._dispatch_intermediate_state
|
||||
del self._dispatch_intermediate_state
|
||||
return self._get_impl(forward_batch).dispatch_b(*inner_state)
|
||||
return self._get_impl().dispatch_b(*inner_state)
|
||||
|
||||
def combine(self, *args, **kwargs) -> Tuple:
|
||||
self.combine_a(*args, **kwargs)
|
||||
@@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional = None,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_batch).combine_a(
|
||||
inner_state = self._get_impl().combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_ids=topk_ids,
|
||||
topk_weights=topk_weights,
|
||||
)
|
||||
self._combine_intermediate_state = forward_batch, inner_state
|
||||
self._combine_intermediate_state = inner_state
|
||||
|
||||
def combine_b(self):
|
||||
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
||||
forward_batch, inner_state = self._combine_intermediate_state
|
||||
inner_state = self._combine_intermediate_state
|
||||
del self._combine_intermediate_state
|
||||
return self._get_impl(forward_batch).combine_b(*inner_state)
|
||||
return self._get_impl().combine_b(*inner_state)
|
||||
|
||||
def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||
forward_batch.is_extend_in_batch
|
||||
)
|
||||
def _get_impl(self) -> _MooncakeEPDispatcherImpl:
|
||||
is_extend_in_batch = get_is_extend_in_batch()
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
|
||||
if resolved_deepep_mode == DeepEPMode.NORMAL:
|
||||
raise NotImplementedError
|
||||
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
|
||||
@@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher):
|
||||
def _update_stage(self, old_stage, new_stage):
|
||||
assert self._stage == old_stage
|
||||
self._stage = new_stage
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
    """No-op: the given quantization config is accepted and discarded.

    Args:
        quant_config: Quantization settings; not stored or inspected here.
    """
    return None
|
||||
|
||||
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_rank,
|
||||
get_moe_expert_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
BaseDispatcher,
|
||||
CombineInput,
|
||||
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
DispatchOutput,
|
||||
DispatchOutputFormat,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
|
||||
from sglang.srt.layers.moe.utils import get_moe_runner_backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)
|
||||
|
||||
class StandardDispatcher(BaseDispatcher):
|
||||
|
||||
def __init__(self, moe_runner_config: MoeRunnerConfig):
    """Record the expert-parallel layout used later to localize expert ids.

    Args:
        moe_runner_config: Runner config; its ``num_experts`` and
            ``num_local_experts`` fields are copied onto the dispatcher.
    """
    # Size of the MoE expert-parallel group.
    self.moe_ep_size = get_moe_expert_parallel_world_size()

    runner_backend = get_moe_runner_backend()
    self.enable_flashinfer_cutlass_moe = runner_backend.is_flashinfer_cutlass()

    self.num_experts = moe_runner_config.num_experts
    self.num_local_experts = moe_runner_config.num_local_experts

    # This rank's index within the MoE expert-parallel group.
    self.moe_ep_rank = get_moe_expert_parallel_rank()

    # Global-to-local expert id table; created on demand by dispatch().
    self.local_expert_mapping = None
|
||||
|
||||
def dispatch(
    self, hidden_states: torch.Tensor, topk_output: TopKOutput
) -> DispatchOutput:
    """Package the inputs for the MoE runner, localizing expert ids under EP.

    When expert parallelism is active (``moe_ep_size > 1``), the flashinfer
    cutlass backend is not in use, and ``topk_output`` is in the standard
    format, a lookup table from global expert id to rank-local expert id is
    built once and cached on ``self.local_expert_mapping``. Experts outside
    this rank's slice map to ``-1``. Once the table exists, every
    standard-format ``topk_output`` has its ``topk_ids`` rewritten through it.

    Args:
        hidden_states: Token activations; passed through unchanged.
        topk_output: Router output; only its ``topk_ids`` may be replaced.

    Returns:
        A ``StandardDispatchOutput`` holding ``hidden_states`` and the
        (possibly remapped) ``topk_output``.

    Raises:
        NotImplementedError: If the table exists and ``topk_output`` is in
            the triton-kernel format.
    """
    is_standard = TopKOutputChecker.format_is_standard(topk_output)

    needs_mapping = (
        self.local_expert_mapping is None
        and self.moe_ep_size > 1
        and not self.enable_flashinfer_cutlass_moe
        and is_standard
    )
    if needs_mapping:
        # Allocate the table on the device of the ids that will index it,
        # instead of a hard-coded "cuda": the lookup below needs both
        # tensors on the same device (other accelerators, or a CUDA device
        # that is not the current one).
        device = topk_output.topk_ids.device
        mapping = torch.full(
            (self.num_experts,), -1, dtype=torch.int32, device=device
        )
        # This rank owns the contiguous global ids [first, first + n_local).
        first = self.moe_ep_rank * self.num_local_experts
        mapping[first : first + self.num_local_experts] = torch.arange(
            self.num_local_experts, dtype=torch.int32, device=device
        )
        self.local_expert_mapping = mapping

    if self.local_expert_mapping is not None:
        if is_standard:
            topk_output = topk_output._replace(
                topk_ids=self.local_expert_mapping[topk_output.topk_ids]
            )
        elif TopKOutputChecker.format_is_triton_kernel(topk_output):
            raise NotImplementedError()

    return StandardDispatchOutput(
        hidden_states=hidden_states, topk_output=topk_output
    )
|
||||
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
|
||||
# TODO: this branch should be removed in the future
|
||||
assert isinstance(combine_input, torch.Tensor)
|
||||
return combine_input
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
    """No-op: the given quantization config is accepted and discarded.

    Args:
        quant_config: Quantization settings; not stored or inspected here.
    """
    return None
|
||||
|
||||
@@ -365,9 +365,10 @@ class TopK(CustomOp):
|
||||
def empty_topk_output(self, device: torch.device) -> TopKOutput:
|
||||
topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
|
||||
topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
|
||||
topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
|
||||
topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
|
||||
# FIXME: router_logits should be of size (0, num_experts)
|
||||
router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
|
||||
return StandardTopKOutput(topk_weights, topk_idx, router_logits)
|
||||
return StandardTopKOutput(topk_weights, topk_ids, router_logits)
|
||||
|
||||
|
||||
# ------------------------------- TopK implementation -------------------------------------
|
||||
|
||||
@@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
(1 / w2_input_scale).to(torch.float32), requires_grad=False
|
||||
)
|
||||
|
||||
layer.dispatcher.set_quant_config(
|
||||
{"input_global_scale": layer.w13_input_scale_quant}
|
||||
)
|
||||
|
||||
# Validate weight scales
|
||||
for name, weight_scale in [
|
||||
("w13", layer.w13_weight_scale),
|
||||
|
||||
@@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
hidden_states, topk_idx, topk_weights = (
|
||||
dispatch_output.hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_ids,
|
||||
dispatch_output.topk_weights,
|
||||
)
|
||||
if isinstance(hidden_states, tuple):
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
set_is_extend_in_batch,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
|
||||
@@ -639,6 +640,7 @@ class CudaGraphRunner:
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
set_is_extend_in_batch(False)
|
||||
|
||||
kwargs = {}
|
||||
if (
|
||||
|
||||
@@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_dp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
set_is_extend_in_batch,
|
||||
)
|
||||
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
|
||||
|
||||
@@ -688,6 +689,7 @@ class ForwardBatch:
|
||||
|
||||
self.global_dp_buffer_len = buffer_len
|
||||
set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
|
||||
set_is_extend_in_batch(self.is_extend_in_batch)
|
||||
|
||||
bs = self.batch_size
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
set_is_extend_in_batch,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.torchao_utils import save_gemlite_cache
|
||||
@@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner:
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
# FIXME: the implementation is hacky. `is_extend_in_batch` is used to determine the deepep mode.
|
||||
# It is True in this context, but we set it to False here so that the low latency deepep mode is used.
|
||||
set_is_extend_in_batch(False)
|
||||
|
||||
kwargs = {}
|
||||
with set_forward_context(forward_batch, self.attention_layers):
|
||||
|
||||
@@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
|
||||
if self.num_shared_experts > 0:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
topk_output = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
@@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_idx = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
|
||||
if self.ep_size > 1:
|
||||
(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
reorder_topk_ids,
|
||||
num_recv_tokens_per_expert,
|
||||
seg_indptr,
|
||||
masked_m,
|
||||
expected_m,
|
||||
) = self.deepep_dispatcher.dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
reorder_topk_ids=reorder_topk_ids,
|
||||
seg_indptr=seg_indptr,
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_batch=forward_batch,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
final_hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
final_hidden_states += shared_output
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import (
|
||||
get_deepep_mode,
|
||||
get_moe_a2a_backend,
|
||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
@@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.server_args import get_global_server_args
|
||||
from sglang.srt.single_batch_overlap import SboFlags
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
MaybeTboDeepEPDispatcher,
|
||||
model_forward_maybe_tbo,
|
||||
)
|
||||
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
||||
from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
LazyValue,
|
||||
@@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module):
|
||||
else None
|
||||
)
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=get_deepep_mode(),
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
self._enable_a2a_moe = (
|
||||
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
|
||||
)
|
||||
@@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
router_logits = self.gate(hidden_states)
|
||||
if not self._fuse_shared_experts_inside_sbo:
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
topk_output = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
@@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
|
||||
hidden_states.device
|
||||
)
|
||||
topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
|
||||
if self._fuse_shared_experts_inside_sbo:
|
||||
shared_output = None
|
||||
@@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
topk_output=topk_output,
|
||||
**(
|
||||
dict(
|
||||
forward_shared_experts=_forward_shared_experts_and_put_results,
|
||||
@@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
|
||||
state.topk_output = self.topk(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
@@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
state.topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.experts.deepep_dispatcher.dispatch_a(
|
||||
self.experts.dispatcher.dispatch_a(
|
||||
hidden_states=state.hidden_states_mlp_input,
|
||||
input_global_scale=None,
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_batch=state.forward_batch,
|
||||
topk_output=state.pop("topk_output"),
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -992,32 +963,29 @@ class DeepseekV2MoE(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||
state.dispatch_output = self.experts.dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||
state.hidden_states_experts_output = self.experts.run_moe_core(
|
||||
dispatch_output=state.dispatch_output,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.experts.deepep_dispatcher.combine_a(
|
||||
self.experts.dispatcher.combine_a(
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.dispatch_output.topk_idx,
|
||||
topk_ids=state.dispatch_output.topk_ids,
|
||||
topk_weights=state.dispatch_output.topk_weights,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.pop("dispatch_output")
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = (
|
||||
self.experts.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
|
||||
@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
parallel_state,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
@@ -49,7 +48,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe import get_moe_a2a_backend
|
||||
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||
from sglang.srt.layers.moe.topk import TopK
|
||||
@@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import (
|
||||
DeepseekV2MoE,
|
||||
)
|
||||
from sglang.srt.server_args import get_global_server_args
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
from sglang.srt.utils import (
|
||||
BumpAllocator,
|
||||
LazyValue,
|
||||
@@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
else None
|
||||
)
|
||||
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
num_experts=self.num_experts,
|
||||
num_local_experts=config.n_routed_experts // self.tp_size,
|
||||
hidden_size=config.hidden_size,
|
||||
params_dtype=config.torch_dtype,
|
||||
deepep_mode=get_deepep_mode(),
|
||||
async_finish=True,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
|
||||
self._enable_a2a_moe = (
|
||||
get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
|
||||
)
|
||||
|
||||
@@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
topk_output = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
@@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_idx, _ = self.topk.empty_topk_output(
|
||||
hidden_states.device
|
||||
)
|
||||
topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
|
||||
if shared_output is not None:
|
||||
|
||||
@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
if hidden_states.shape[0] > 0:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
topk_output = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
num_token_non_padded=forward_batch.num_token_non_padded,
|
||||
@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
topk_idx = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
topk_weights = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
topk_output=topk_output,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
|
||||
state.topk_output = self.topk(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
||||
@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
state.topk_output = self.topk.empty_topk_output(hidden_states.device)
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.experts.deepep_dispatcher.dispatch_a(
|
||||
self.experts.dispatcher.dispatch_a(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_batch=state.forward_batch,
|
||||
topk_output=state.pop("topk_output"),
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -250,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
|
||||
state.dispatch_output = self.experts.dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
state.hidden_states_experts_output = self.experts.moe_impl(
|
||||
state.hidden_states_experts_output = self.experts.run_moe_core(
|
||||
dispatch_output=state.dispatch_output,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.experts.deepep_dispatcher.combine_a(
|
||||
self.experts.dispatcher.combine_a(
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.dispatch_output.topk_idx,
|
||||
topk_ids=state.dispatch_output.topk_ids,
|
||||
topk_weights=state.dispatch_output.topk_weights,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.pop("dispatch_output")
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = (
|
||||
self.experts.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
@@ -5,12 +21,12 @@ import torch
|
||||
|
||||
from sglang.srt.layers import deep_gemm_wrapper
|
||||
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import get_int_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
|
||||
|
||||
class SboFlags:
|
||||
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:
|
||||
|
||||
def execute_sbo(
|
||||
forward_shared_experts: Callable[[], Any],
|
||||
experts: "DeepEPMoE",
|
||||
experts: FusedMoE,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
alt_stream: Optional = None,
|
||||
topk_output: TopKOutput,
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
disable_sbo: bool = False,
|
||||
):
|
||||
dispatch_output = experts.dispatch(
|
||||
hidden_states, topk_idx, topk_weights, forward_batch
|
||||
|
||||
dispatch_output = experts.dispatcher.dispatch(
|
||||
hidden_states=hidden_states, topk_output=topk_output
|
||||
)
|
||||
|
||||
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
||||
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
||||
)
|
||||
|
||||
hidden_states = experts.moe_impl(
|
||||
hidden_states = experts.run_moe_core(
|
||||
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||
)
|
||||
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
||||
@@ -83,11 +98,10 @@ def execute_sbo(
|
||||
):
|
||||
forward_shared_experts()
|
||||
|
||||
hidden_states = experts.combine(
|
||||
hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_weights,
|
||||
forward_batch,
|
||||
hidden_states = experts.dispatcher.combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_ids=dispatch_output.topk_ids,
|
||||
topk_weights=dispatch_output.topk_weights,
|
||||
overlap_args=combine_overlap_args,
|
||||
)
|
||||
|
||||
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
|
||||
):
|
||||
return None, None, {}
|
||||
|
||||
hidden_states = dispatch_output.hidden_states_fp8
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
hidden_states = dispatch_output.hidden_states
|
||||
|
||||
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_is_extend_in_batch,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
set_is_extend_in_batch(False)
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
|
||||
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
|
||||
get_global_graph_memory_pool,
|
||||
model_capture_mode,
|
||||
set_global_graph_memory_pool,
|
||||
set_is_extend_in_batch,
|
||||
set_torch_compile_config,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
# Clean intermediate result cache for DP attention
|
||||
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||
set_dp_buffer_len(global_dp_buffer_len, num_tokens)
|
||||
set_is_extend_in_batch(False)
|
||||
|
||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||
|
||||
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:
|
||||
|
||||
def combine_b(self, **kwargs):
|
||||
return self._execute("combine_b", **kwargs)
|
||||
|
||||
def set_quant_config(self, quant_config: dict):
|
||||
for inner in self._inners:
|
||||
inner.set_quant_config(quant_config)
|
||||
|
||||
Reference in New Issue
Block a user