[9/N] MoE Refactor: cleanup dispatcher interfaces (#11847)
@@ -87,6 +87,7 @@ class _DpGatheredBufferWrapper:
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
     _global_num_tokens: Optional[List[int]]
+    _is_extend_in_batch: bool
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -145,6 +146,14 @@ class _DpGatheredBufferWrapper:
     def get_dp_device(cls) -> torch.device:
         return cls._device
 
+    @classmethod
+    def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
+        cls._is_extend_in_batch = is_extend_in_batch
+
+    @classmethod
+    def get_is_extend_in_batch(cls) -> bool:
+        return cls._is_extend_in_batch
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -188,6 +197,14 @@ def get_dp_device() -> torch.device:
     return _DpGatheredBufferWrapper.get_dp_device()
 
 
+def set_is_extend_in_batch(is_extend_in_batch: bool):
+    _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch)
+
+
+def get_is_extend_in_batch() -> bool:
+    return _DpGatheredBufferWrapper.get_is_extend_in_batch()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
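Note: the three hunks above move `is_extend_in_batch` into process-global state on `_DpGatheredBufferWrapper`, which is what later lets `DeepEPDispatcher._get_impl()` drop its `ForwardBatch` argument. A minimal, runnable sketch of the pattern (simplified names, not the real class):

```python
# Sketch of the class-level flag pattern used above (hypothetical stand-in,
# not the real _DpGatheredBufferWrapper).
class _Flags:
    _is_extend_in_batch: bool = False

    @classmethod
    def set_is_extend_in_batch(cls, value: bool) -> None:
        cls._is_extend_in_batch = value

    @classmethod
    def get_is_extend_in_batch(cls) -> bool:
        return cls._is_extend_in_batch


def set_is_extend_in_batch(value: bool) -> None:
    _Flags.set_is_extend_in_batch(value)


def get_is_extend_in_batch() -> bool:
    return _Flags.get_is_extend_in_batch()


# The model runner sets the flag once per batch; the dispatcher reads it later
# without needing a ForwardBatch handle.
set_is_extend_in_batch(True)
assert get_is_extend_in_batch() is True
```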
@@ -566,7 +566,9 @@ def ep_scatter(
     scale_hidden_size = ceil_div(scale_hidden_size, 4)
 
     assert m_indices.shape[0] % BLOCK_E == 0
-    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert (
+        recv_x_scale.dtype == output_tensor_scale.dtype
+    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
     assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
@@ -20,18 +20,14 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
 )
-from sglang.srt.layers.quantization.modelopt_quant import (
-    CUTEDSL_MOE_NVFP4_DISPATCH,
-    ModelOptNvFp4FusedMoEMethod,
-)
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 from sglang.srt.utils.offloader import get_offloader
@@ -109,23 +105,6 @@ class DeepEPMoE(FusedMoE):
 
         self.deepep_mode = get_deepep_mode()
 
-        # TODO: move to the beginning of the file
-        from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-
-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            params_dtype=params_dtype,
-            deepep_mode=self.deepep_mode,
-            async_finish=True,  # TODO
-            return_recv_hook=True,
-        )
-
         if self.deepep_mode.enable_low_latency() and not _is_npu:
             # NPU supports low_latency deepep without deepgemm
             assert (
@@ -165,19 +144,16 @@ class DeepEPMoE(FusedMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
         forward_shared_experts=None,
         alt_stream=None,
         disable_sbo=False,
     ):
 
         # We have to call SBO inside MoE to be compatible with hooks used in offloading
         return single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
             # SBO args
             experts=self,
             forward_shared_experts=forward_shared_experts,
@@ -188,25 +164,14 @@ class DeepEPMoE(FusedMoE):
     def dispatch(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
-        return self.deepep_dispatcher.dispatch(
+        return self.dispatcher.dispatch(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            input_global_scale=(
-                self.w13_input_scale_quant
-                if isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-                and self.quant_method.enable_flashinfer_cutedsl_moe
-                and CUTEDSL_MOE_NVFP4_DISPATCH
-                else None
-            ),
+            topk_output=topk_output,
         )
 
-    def moe_impl(
+    def run_moe_core(
         self,
         dispatch_output: DispatchOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
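Note: `dispatch()` now takes a single `TopKOutput` instead of separate `topk_idx`/`topk_weights` plus a `ForwardBatch`, and the `input_global_scale` plumbing moves into the dispatcher's quant config (see the `set_quant_config` hunks below). A small runnable sketch of the bundled shape the dispatchers unpack (a stand-in NamedTuple; the real `TopKOutput` lives in `sglang.srt.layers.moe.topk`):

```python
from typing import NamedTuple

import torch


class TopKOutputSketch(NamedTuple):
    # Stand-in with the two fields this diff reads from the real TopKOutput.
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor


topk_output = TopKOutputSketch(
    topk_weights=torch.rand(4, 2),
    topk_ids=torch.randint(0, 8, (4, 2)),
)
# Dispatchers unpack the bundle themselves, as in dispatch_a below:
topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
```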
@@ -240,16 +205,14 @@ class DeepEPMoE(FusedMoE):
     def combine(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional[Dict[str, Any]] = None,
     ):
-        return self.deepep_dispatcher.combine(
+        return self.dispatcher.combine(
            hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
-            forward_batch=forward_batch,
             overlap_args=overlap_args,
         )
 
@@ -257,9 +220,9 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
-        hidden_states, topk_idx, topk_weights = (
+        hidden_states, topk_ids, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if hidden_states.shape[0] == 0:
@@ -267,15 +230,15 @@ class DeepEPMoE(FusedMoE):
         # in original deepep, idx == -1 meaning invalid and will not be processed.
         # aiter does not accept -1, we use a expert mask to make these idx invalid
         # (idx == num_local_experts) meaning not used in aiter fused_moe
-        topk_idx_copy = topk_idx.to(torch.int32)
-        topk_idx_copy[topk_idx_copy == -1] = self.num_local_experts
+        topk_ids_copy = topk_ids.to(torch.int32)
+        topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts
 
         return fused_moe(
             hidden_states,
             self.w13_weight,
             self.w2_weight,
             topk_weights,
-            topk_idx_copy,
+            topk_ids_copy,
             w1_scale=self.w13_weight_scale_inv,
             w2_scale=self.w2_weight_scale_inv,
             quant_type=QuantType.per_128x128,
@@ -291,18 +254,21 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPNormalOutput,
     ):
-        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
-            dispatch_output
-        )
-        hidden_states_fp8, hidden_states_scale = hidden_states_fp8
+        (
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
+        ) = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
-            return hidden_states_fp8.bfloat16()
+            return hidden_states.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
         if all_tokens <= 0:
-            return hidden_states_fp8.bfloat16()
-        M, K = hidden_states_fp8.size()
+            return hidden_states.bfloat16()
+        M, K = hidden_states.size()
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
@@ -323,35 +289,35 @@ class DeepEPMoE(FusedMoE):
             ),
         )
 
-        hidden_states_fp8_shape = hidden_states_fp8.shape
-        hidden_states_fp8_device = hidden_states_fp8.device
-        hidden_states_fp8_dtype = hidden_states_fp8.dtype
+        hidden_states_shape = hidden_states.shape
+        hidden_states_device = hidden_states.device
+        hidden_states_dtype = hidden_states.dtype
 
         input_tensor = [
             torch.empty(
                 (all_tokens, K),
-                device=hidden_states_fp8.device,
-                dtype=hidden_states_fp8.dtype,
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
             ),
             (
                 # TODO check whether need `zeros`
                 torch.zeros(
                     (ceil_div(K // 128, 4), all_tokens),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.int,
                 ).transpose(0, 1)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
                 else torch.empty(
                     (all_tokens, K // 128),
-                    device=hidden_states_fp8.device,
+                    device=hidden_states.device,
                     dtype=torch.float32,
                 )
             ),
         ]
         m_indices = torch.empty(
-            all_tokens, device=hidden_states_fp8.device, dtype=torch.int32
+            all_tokens, device=hidden_states.device, dtype=torch.int32
         )
-        output_index = torch.empty_like(topk_idx)
+        output_index = torch.empty_like(topk_ids)
 
         if get_offloader().forbid_copy_engine_usage:
             num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
@@ -367,9 +333,9 @@ class DeepEPMoE(FusedMoE):
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
 
         ep_scatter(
-            hidden_states_fp8,
+            hidden_states,
             hidden_states_scale,
-            topk_idx,
+            topk_ids,
             num_recv_tokens_per_expert_gpu,
             expert_start_loc,
             input_tensor[0],
@@ -378,11 +344,11 @@ class DeepEPMoE(FusedMoE):
             output_index,
             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
-        dispose_tensor(hidden_states_fp8)
+        dispose_tensor(hidden_states)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
@@ -403,7 +369,7 @@ class DeepEPMoE(FusedMoE):
         del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8_device,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
@@ -425,11 +391,11 @@ class DeepEPMoE(FusedMoE):
         del down_input_fp8, down_input_scale
 
         gather_out = torch.empty(
-            hidden_states_fp8_shape,
-            device=hidden_states_fp8_device,
+            hidden_states_shape,
+            device=hidden_states_device,
             dtype=torch.bfloat16,
         )
-        ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
+        ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
 
         return gather_out
 
@@ -438,13 +404,13 @@ class DeepEPMoE(FusedMoE):
         dispatch_output: DeepEPLLOutput,
         down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
-        hidden_states, _, _, masked_m, _ = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
 
         output = self.quant_method.apply_without_routing_weights(
             layer=self,
-            x=hidden_states,
+            x=(hidden_states, hidden_states_scale),
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
             down_gemm_overlap_args=down_gemm_overlap_args,
@@ -466,25 +432,28 @@ class DeepEPMoE(FusedMoE):
         self,
         dispatch_output: DeepEPLLOutput,
     ):
-        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
+        hidden_states, hidden_states_scale, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.moe_runner_config.activation == "silu"
+        assert (
+            hidden_states_scale.dtype == torch.float32
+        ), f"hidden_states_scale.dtype: {hidden_states_scale.dtype}"
 
         # GroupGemm-0
-        num_groups, m, k = hidden_states_fp8[0].size()
+        num_groups, m, k = hidden_states.size()
         n = self.w13_weight.size(1)
         expected_m = min(expected_m, m)
         gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16
+            (num_groups, m, n), device=hidden_states.device, dtype=torch.bfloat16
        )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            hidden_states_fp8,
+            (hidden_states, hidden_states_scale),
             self.w13_weight_fp8,
             gateup_output,
             masked_m,
             expected_m,
         )
-        dispose_tensor(hidden_states_fp8[0])
+        dispose_tensor(hidden_states)
 
         # Act
         down_input = torch.empty(
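Note: with the FP8 payload split into explicit `hidden_states` / `hidden_states_scale` fields, the two are only zipped back into a `(data, scale)` pair at the DeepGEMM boundary, as in the hunk above. A sketch with made-up shapes (the 128-wide block scaling and the dtypes are assumptions based on the per-128 quantization used here; requires a torch build with `float8_e4m3fn`):

```python
import torch

num_groups, m, k = 8, 16, 256
hidden_states = torch.zeros(num_groups, m, k, dtype=torch.float8_e4m3fn)
# One float32 scale per 128-wide block along k (assumed layout).
hidden_states_scale = torch.ones(num_groups, m, k // 128, dtype=torch.float32)
# What grouped_gemm_nt_f8f8bf16_masked consumes as its left-hand side:
lhs = (hidden_states, hidden_states_scale)
```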
@@ -557,11 +526,9 @@ class DeepEPMoE(FusedMoE):
         def _forward_normal(dispatch_output: DeepEPNormalOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPNormalOutput)
-            hidden_states, _, _, num_recv_tokens_per_expert = dispatch_output
-
-            if isinstance(hidden_states, tuple):
-                per_token_scale = hidden_states[1]
-                hidden_states = hidden_states[0]
+            hidden_states, hidden_states_scale, _, _, num_recv_tokens_per_expert = (
+                dispatch_output
+            )
 
             group_list = torch.tensor(num_recv_tokens_per_expert, dtype=torch.int64).to(
                 hidden_states.device
@@ -571,7 +538,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
                 weight=[self.w13_weight.permute(0, 2, 1)],
-                # per_token_scale=[per_token_scale],
+                # per_token_scale=[hidden_states_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,
@@ -591,7 +558,7 @@ class DeepEPMoE(FusedMoE):
                 )[0]
             else:
                 if not get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT"):
-                    hidden_states, per_token_scale = torch_npu.npu_dynamic_quant(
+                    hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant(
                         hidden_states
                     )
                 # gmm1: gate_up_proj
@@ -599,7 +566,7 @@ class DeepEPMoE(FusedMoE):
                     x=[hidden_states],
                     weight=[self.w13_weight],
                     scale=[self.w13_weight_scale.to(output_dtype)],
-                    per_token_scale=[per_token_scale],
+                    per_token_scale=[hidden_states_scale],
                     split_item=2,
                     group_list_type=group_list_type,
                     group_type=0,
@@ -631,11 +598,14 @@ class DeepEPMoE(FusedMoE):
         def _forward_ll(dispatch_output: DeepEPLLOutput):
             if TYPE_CHECKING:
                 assert isinstance(dispatch_output, DeepEPLLOutput)
-            hidden_states, topk_idx, topk_weights, group_list, _ = dispatch_output
-
-            if isinstance(hidden_states, tuple):
-                per_token_scale = hidden_states[1]
-                hidden_states = hidden_states[0]
+            (
+                hidden_states,
+                hidden_states_scale,
+                topk_ids,
+                topk_weights,
+                group_list,
+                _,
+            ) = dispatch_output
 
             group_list = group_list.to(torch.int64)
 
@@ -644,7 +614,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states = torch_npu.npu_grouped_matmul(
                 x=[hidden_states],
                 weight=[self.w13_weight.permute(0, 2, 1)],
-                # per_token_scale=[per_token_scale],
+                # per_token_scale=[hidden_states_scale],
                 split_item=2,
                 group_list_type=group_list_type,
                 group_type=0,
@@ -678,7 +648,7 @@ class DeepEPMoE(FusedMoE):
             hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
                 x=hidden_states,
                 weight_scale=self.w13_weight_scale.to(torch.float32),
-                activation_scale=per_token_scale,
+                activation_scale=hidden_states_scale,
                 bias=None,
                 quant_scale=None,
                 quant_offset=None,
@@ -11,14 +11,19 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe import (
     MoeRunnerConfig,
+    get_deepep_mode,
+    get_moe_a2a_backend,
     get_moe_runner_backend,
     should_use_flashinfer_trtllm_moe,
 )
+from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
+from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardDispatcher,
     StandardDispatchOutput,
@@ -32,6 +37,7 @@ from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -71,6 +77,27 @@ def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     return tile_tokens_dim
 
 
+def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
+    a2a_backend = get_moe_a2a_backend()
+    if a2a_backend.is_none():
+        return StandardDispatcher(moe_runner_config)
+    elif a2a_backend.is_deepep():
+        return MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=moe_runner_config.top_k,
+            permute_fusion=True,
+            num_experts=moe_runner_config.num_experts,
+            num_local_experts=moe_runner_config.num_local_experts,
+            hidden_size=moe_runner_config.hidden_size,
+            params_dtype=moe_runner_config.params_dtype,
+            deepep_mode=get_deepep_mode(),
+            async_finish=True,
+            return_recv_hook=True,
+        )
+    else:
+        raise NotImplementedError(f"Unsupported a2a backend: {a2a_backend}")
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
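Note: dispatcher construction is now centralized in `create_moe_dispatcher`, keyed off the A2A backend, instead of being hand-built inside `DeepEPMoE.__init__`. A stripped-down, runnable version of the selection pattern (stand-in enum and string results; the real factory returns `StandardDispatcher` or `MaybeTboDeepEPDispatcher`):

```python
from enum import Enum


class A2ABackendSketch(Enum):
    NONE = "none"
    DEEPEP = "deepep"

    def is_none(self) -> bool:
        return self is A2ABackendSketch.NONE

    def is_deepep(self) -> bool:
        return self is A2ABackendSketch.DEEPEP


def create_dispatcher(backend: A2ABackendSketch) -> str:
    # Same branching shape as create_moe_dispatcher above.
    if backend.is_none():
        return "StandardDispatcher"
    elif backend.is_deepep():
        return "MaybeTboDeepEPDispatcher"
    else:
        raise NotImplementedError(f"Unsupported a2a backend: {backend}")


assert create_dispatcher(A2ABackendSketch.NONE) == "StandardDispatcher"
assert create_dispatcher(A2ABackendSketch.DEEPEP) == "MaybeTboDeepEPDispatcher"
```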
@@ -132,8 +159,6 @@ class FusedMoE(torch.nn.Module):
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.expert_map_cpu = None
-        self.expert_map_gpu = None
 
         enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
 
@@ -149,19 +174,6 @@ class FusedMoE(torch.nn.Module):
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
 
-        if self.moe_ep_size > 1:
-            # TODO(ch-wan): support shared experts fusion
-            # Create a tensor of size num_experts filled with -1
-            self.expert_map_cpu = torch.full(
-                (self.num_experts,), -1, dtype=torch.int32, device="cpu"
-            )
-            # Create a expert map for the local experts
-            self.expert_map_cpu[
-                self.moe_ep_rank
-                * self.num_local_experts : (self.moe_ep_rank + 1)
-                * self.num_local_experts
-            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
         assert intermediate_size % self.moe_tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
@@ -219,7 +231,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
-        self.dispatcher = StandardDispatcher()
+        self.dispatcher = create_moe_dispatcher(self.moe_runner_config)
 
         self.should_fuse_routed_scaling_factor_in_topk = isinstance(
             self.quant_method, ModelOptNvFp4FusedMoEMethod
@@ -453,9 +465,12 @@ class FusedMoE(torch.nn.Module):
             expert_data.copy_(loaded_weight)
 
     def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
-        if self.expert_map_cpu is None:
-            return expert_id
-        return self.expert_map_cpu[expert_id].item()
+        start_idx = self.moe_ep_rank * self.num_local_experts
+        end_idx = (self.moe_ep_rank + 1) * self.num_local_experts
+        if start_idx <= expert_id < end_idx:
+            return expert_id - start_idx
+        else:
+            return -1
 
     def weight_loader(
         self,
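Note: the `expert_map_cpu` lookup tensor is replaced by pure arithmetic, which works because each EP rank owns a contiguous range of global expert IDs. A runnable check of the new mapping with hypothetical sizes (8 local experts on EP rank 2 own global IDs 16..23):

```python
def map_global_to_local(expert_id: int, ep_rank: int, num_local_experts: int) -> int:
    # Same logic as the rewritten _map_global_expert_id_to_local_expert_id.
    start_idx = ep_rank * num_local_experts
    end_idx = (ep_rank + 1) * num_local_experts
    if start_idx <= expert_id < end_idx:
        return expert_id - start_idx
    return -1  # -1 marks "not owned by this rank", matching the old map's fill value


assert map_global_to_local(18, ep_rank=2, num_local_experts=8) == 2
assert map_global_to_local(3, ep_rank=2, num_local_experts=8) == -1
```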
@@ -804,32 +819,18 @@ class FusedMoE(torch.nn.Module):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None
 
-        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
-            if self.expert_map_cpu is not None and self.expert_map_gpu is None:
-                # If we are in EP mode, we need to move the expert map to GPU.
-                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
-
-        if self.expert_map_gpu is not None:
-            if TopKOutputChecker.format_is_standard(topk_output):
-                topk_output = topk_output._replace(
-                    topk_ids=self.expert_map_gpu[topk_output.topk_ids]
-                )
-            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
-                raise NotImplementedError()
-
         dispatch_output = self.dispatcher.dispatch(
             hidden_states=hidden_states, topk_output=topk_output
         )
 
-        # TODO: consider using symmetric memory
-        combine_input = self.quant_method.apply(
-            layer=self,
+        combine_input = self.run_moe_core(
             dispatch_output=dispatch_output,
             **kwargs,
         )
 
         final_hidden_states = self.dispatcher.combine(combine_input)
 
+        # TODO: should we add some conditions here?
         final_hidden_states = final_hidden_states[
             ..., :origin_hidden_states_dim
         ].contiguous()
@@ -839,6 +840,14 @@ class FusedMoE(torch.nn.Module):
 
         return final_hidden_states
 
+    def run_moe_core(self, dispatch_output: DispatchOutput, **kwargs) -> CombineInput:
+        # TODO: consider using symmetric memory
+        return self.quant_method.apply(
+            layer=self,
+            dispatch_output=dispatch_output,
+            **kwargs,
+        )
+
     @classmethod
     def make_expert_params_mapping(
         cls,
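Note: after these two hunks, `forward` reads as three explicit stages (`dispatcher.dispatch`, then `run_moe_core`, then `dispatcher.combine`), with the quant-method `apply` hidden behind `run_moe_core`. A runnable stub of the contract (hypothetical stand-ins, not the sglang classes):

```python
class StubDispatcher:
    def dispatch(self, hidden_states, topk_output):
        return {"hidden_states": hidden_states, "topk_output": topk_output}

    def combine(self, combine_input):
        return combine_input["hidden_states"]


class StubMoE:
    def __init__(self):
        self.dispatcher = StubDispatcher()

    def run_moe_core(self, dispatch_output, **kwargs):
        # In sglang this delegates to self.quant_method.apply(...).
        return dispatch_output

    def forward(self, hidden_states, topk_output):
        dispatch_output = self.dispatcher.dispatch(hidden_states, topk_output)
        combine_input = self.run_moe_core(dispatch_output)
        return self.dispatcher.combine(combine_input)


assert StubMoE().forward("x", "topk") == "x"
```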
@@ -23,6 +23,7 @@ from sglang.srt.layers.moe.token_dispatcher.mooncake import (
 )
 from sglang.srt.layers.moe.token_dispatcher.standard import (
     StandardCombineInput,
+    StandardDispatcher,
     StandardDispatchOutput,
 )
 
@@ -38,6 +39,7 @@ __all__ = [
     "MooncakeCombineInput",
     "MooncakeDispatchOutput",
     "MooncakeEPDispatcher",
+    "StandardDispatcher",
     "StandardDispatchOutput",
     "StandardCombineInput",
     "DeepEPConfig",
@@ -73,7 +73,7 @@ class DispatchOutputFormat(Enum):
 class DispatchOutput(Protocol):
     """Protocol for dispatch outputs in different formats."""
 
-    # TODO: add hidden_states to the protocol
+    hidden_states: torch.Tensor
 
     @property
     def format(self) -> DispatchOutputFormat: ...
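Note: promoting `hidden_states` from a TODO comment to a declared protocol attribute means generic code can read `dispatch_output.hidden_states` against the `DispatchOutput` protocol without per-format checks. A minimal runnable sketch of why that helps (stand-in names):

```python
from typing import NamedTuple, Protocol

import torch


class HasHiddenStates(Protocol):
    # Mirrors the attribute this hunk adds to DispatchOutput.
    hidden_states: torch.Tensor


class SomeDispatchOutput(NamedTuple):
    hidden_states: torch.Tensor


def num_tokens(out: HasHiddenStates) -> int:
    # Works for any dispatch-output format that satisfies the protocol.
    return out.hidden_states.shape[0]


assert num_tokens(SomeDispatchOutput(torch.zeros(3, 8))) == 3
```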
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers import deep_gemm_wrapper
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
@@ -15,6 +16,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     get_deepep_config,
@@ -51,8 +53,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
 logger = logging.getLogger(__name__)
@@ -61,9 +61,9 @@ logger = logging.getLogger(__name__)
 class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
-    hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
-    # hidden_states_scale
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]
 
@@ -75,8 +75,9 @@ class DeepEPNormalOutput(NamedTuple):
 class DeepEPLLOutput(NamedTuple):
     """DeepEP low latency dispatch output."""
 
-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: Optional[torch.Tensor]
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     masked_m: torch.Tensor
     expected_m: int
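Note: both DeepEP output tuples flatten the old "tensor or `(fp8, scale)` tuple" union into explicit `hidden_states` / `hidden_states_scale` fields, and rename `topk_idx` to `topk_ids`. A runnable sketch of the new low-latency tuple and its unpacking (stand-in class; field order matches the diff):

```python
from typing import NamedTuple, Optional

import torch


class LLOutputSketch(NamedTuple):
    hidden_states: torch.Tensor
    hidden_states_scale: Optional[torch.Tensor]
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int


out = LLOutputSketch(
    hidden_states=torch.zeros(2, 4, 8),
    hidden_states_scale=None,  # None for bf16 dispatch, a float32 tensor for fp8
    topk_ids=torch.zeros(4, 2, dtype=torch.int64),
    topk_weights=torch.ones(4, 2),
    masked_m=torch.tensor([1, 1]),
    expected_m=4,
)
# Matches the new unpacking in forward_deepgemm_masked:
hidden_states, hidden_states_scale, _, _, masked_m, expected_m = out
```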
@@ -314,9 +315,7 @@ class _DeepEPDispatcherImplBase:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         raise NotImplementedError
 
@@ -326,7 +325,7 @@ class _DeepEPDispatcherImplBase:
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
@@ -345,15 +344,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
 
         self.async_finish = async_finish
         self.src2dst = None
+        self.quant_config = {}
 
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
-        topk_idx = topk_idx.to(torch.int64)
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        topk_ids = topk_ids.to(torch.int64)
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and not get_moe_runner_backend().is_cutlass()
@@ -367,25 +366,35 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
             )
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, previous_event
+        return hidden_states, topk_ids, topk_weights, previous_event
 
-    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
+    def dispatch_b(self, hidden_states, topk_ids, topk_weights, previous_event):
         (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             num_recv_tokens_per_expert,
             event,
-        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
+        ) = self._dispatch_core(hidden_states, topk_ids, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
 
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
         return DeepEPNormalOutput(
-            hidden_states, topk_idx, topk_weights, num_recv_tokens_per_expert
+            hidden_states,
+            hidden_states_scale,
+            topk_ids,
+            topk_weights,
+            num_recv_tokens_per_expert,
         )
 
     def _dispatch_core(
         self,
         x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         previous_event,
     ):
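Note: DeepEP itself still returns either a bare tensor (bf16 dispatch) or an `(fp8_data, scale)` tuple; `dispatch_b` now normalizes that at the boundary so downstream code only ever sees the explicit fields. The normalization, as a runnable sketch:

```python
def split_scale(recv):
    # Mirrors the isinstance check added in dispatch_b above.
    if isinstance(recv, tuple):
        hidden_states, hidden_states_scale = recv
    else:
        hidden_states, hidden_states_scale = recv, None
    return hidden_states, hidden_states_scale


assert split_scale(("fp8_data", "scale")) == ("fp8_data", "scale")
assert split_scale("bf16_data") == ("bf16_data", None)
```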
@@ -397,7 +406,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             is_token_in_rank,
             previous_event,
         ) = buffer.get_dispatch_layout(
-            topk_idx,
+            topk_ids,
             self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
@@ -409,14 +418,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
 
         (
             recv_x,
-            recv_topk_idx,
+            recv_topk_ids,
             recv_topk_weights,
             num_recv_tokens_per_expert,
             self.handle,
             event,
         ) = buffer.dispatch(
             x,
-            topk_idx=topk_idx,
+            topk_idx=topk_ids,
             topk_weights=topk_weights,
             num_tokens_per_rank=num_tokens_per_rank,
             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
@@ -437,7 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
 
         return (
             recv_x,
-            recv_topk_idx,
+            recv_topk_ids,
             recv_topk_weights,
             num_recv_tokens_per_expert,
             event,
@@ -446,40 +455,16 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_post_reorder_triton_kernel,
-        )
-
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
-            if hidden_states.shape[0] > 0:
-                num_tokens = self.src2dst.shape[0] // self.router_topk
-                output = torch.empty(
-                    (num_tokens, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
-                deepep_post_reorder_triton_kernel[(num_tokens,)](
-                    hidden_states,
-                    output,
-                    self.src2dst,
-                    topk_idx,
-                    topk_weights,
-                    self.router_topk,
-                    hidden_states.shape[1],
-                    BLOCK_SIZE=512,
-                )
-            else:
-                output = torch.zeros(
-                    (0, hidden_states.shape[1]),
-                    device=hidden_states.device,
-                    dtype=hidden_states.dtype,
-                )
+            raise NotImplementedError()  # triton runner was supported but it's temporarily disabled
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
 
@@ -514,6 +499,9 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             self.num_experts,
         )
 
+    def set_quant_config(self, quant_config: dict):
+        self.quant_config = quant_config
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -525,28 +513,27 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         """
         self.return_recv_hook = return_recv_hook
         self.device_module = torch.get_device_module()
+        self.quant_config = {}
 
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         buffer = self._get_buffer()
-        topk_idx = topk_idx.to(torch.int64)
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        topk_ids = topk_ids.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
             + self.num_experts
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
-            input_global_scale,
-            topk_idx,
+            topk_ids,
         )
         return (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -557,7 +544,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def dispatch_b(
         self,
         hidden_states,
-        topk_idx,
+        topk_ids,
         topk_weights,
         masked_m,
         expected_m,
@@ -570,9 +557,15 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
+        if isinstance(hidden_states, tuple):
+            hidden_states, hidden_states_scale = hidden_states
+        else:
+            hidden_states_scale = None
+
         deepep_output = DeepEPLLOutput(
             hidden_states,
-            topk_idx,
+            hidden_states_scale,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -582,10 +575,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
     ):
         use_nvfp4 = use_fp8 = False
+        input_global_scale = self.quant_config.get("input_global_scale", None)
         if input_global_scale is not None:
             use_nvfp4 = True
         elif not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"):
@@ -595,7 +588,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
                 hidden_states,
-                topk_idx,
+                topk_ids,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
                 use_fp8=use_fp8,
@@ -618,13 +611,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             overlap_args=overlap_args,
         )
@@ -644,7 +637,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def _combine_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
         overlap_args: Optional["CombineOverlapArgs"],
     ):
@@ -658,7 +651,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         with ctx:
             combined_hidden_states, event, hook = buffer.low_latency_combine(
                 x=hidden_states,
-                topk_idx=topk_idx,
+                topk_idx=topk_ids,
                 topk_weights=topk_weights,
                 handle=self.handle,
                 async_finish=not self.return_recv_hook,
@@ -688,6 +681,9 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             self.num_experts,
         )
 
+    def set_quant_config(self, quant_config: dict):
+        self.quant_config = quant_config
+
 
 @dataclass
 class _Stage(Enum):
@@ -745,25 +741,20 @@ class DeepEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_batch).dispatch_a(
+        inner_state = self._get_impl().dispatch_a(
             hidden_states=hidden_states,
-            input_global_scale=input_global_scale,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
+            topk_output=topk_output,
         )
-        self._dispatch_intermediate_state = forward_batch, inner_state
+        self._dispatch_intermediate_state = inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_batch, inner_state = self._dispatch_intermediate_state
+        inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+        return self._get_impl().dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -773,30 +764,28 @@ class DeepEPDispatcher(BaseDispatcher):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_batch).combine_a(
+        inner_state = self._get_impl().combine_a(
            hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
             overlap_args=overlap_args,
         )
-        self._combine_intermediate_state = forward_batch, inner_state
+        self._combine_intermediate_state = inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_batch, inner_state = self._combine_intermediate_state
+        inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_batch).combine_b(*inner_state)
+        return self._get_impl().combine_b(*inner_state)
 
-    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
+    def _get_impl(self) -> _DeepEPDispatcherImplBase:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
         if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
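Note: `_get_impl` now resolves normal vs. low-latency from the global DP flag rather than a `ForwardBatch` argument, which is what all the `_get_impl(forward_batch)` to `_get_impl()` changes above feed into. A runnable sketch of the resolve step (hypothetical enum; the auto-mode policy below, prefill to normal and decode to low-latency, is an assumption about how `DeepEPMode.resolve` uses `is_extend_in_batch`):

```python
from enum import Enum


class DeepEPModeSketch(Enum):
    NORMAL = "normal"
    LOW_LATENCY = "low_latency"
    AUTO = "auto"

    def resolve(self, is_extend_in_batch: bool) -> "DeepEPModeSketch":
        if self is not DeepEPModeSketch.AUTO:
            return self
        # Assumption: auto picks normal for extend/prefill batches and
        # low-latency for decode, matching how the dispatcher branches here.
        if is_extend_in_batch:
            return DeepEPModeSketch.NORMAL
        return DeepEPModeSketch.LOW_LATENCY


assert DeepEPModeSketch.AUTO.resolve(True) is DeepEPModeSketch.NORMAL
assert DeepEPModeSketch.AUTO.resolve(False) is DeepEPModeSketch.LOW_LATENCY
```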
@@ -807,3 +796,9 @@ class DeepEPDispatcher(BaseDispatcher):
     def _update_stage(self, old_stage, new_stage):
         assert self._stage == old_stage
         self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        if self.deepep_mode.enable_low_latency():
+            self._low_latency_dispatcher.set_quant_config(quant_config)
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher.set_quant_config(quant_config)
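Note: per-call `input_global_scale` plumbing is replaced by a one-time `set_quant_config` fan-out to both impls; `_dispatch_core` then reads the scale from `self.quant_config` at dispatch time. A runnable sketch of the flow (hypothetical minimal class):

```python
class DispatcherImplSketch:
    def __init__(self):
        self.quant_config: dict = {}

    def set_quant_config(self, quant_config: dict) -> None:
        self.quant_config = quant_config

    def _dispatch_core(self):
        # Read at dispatch time instead of being threaded through every call,
        # mirroring the new _DeepEPDispatcherImplLowLatency._dispatch_core.
        return self.quant_config.get("input_global_scale", None)


impl = DispatcherImplSketch()
assert impl._dispatch_core() is None
impl.set_quant_config({"input_global_scale": 1.0})
assert impl._dispatch_core() == 1.0
```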
@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from typing import NamedTuple, Optional, Tuple

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+from sglang.srt.layers.dp_attention import get_is_extend_in_batch
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -12,6 +13,7 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.utils import get_int_env_var

@@ -27,16 +29,15 @@ from enum import Enum, auto
 import torch
 import torch.distributed as dist

-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-
 logger = logging.getLogger(__name__)


 class MooncakeDispatchOutput(NamedTuple):
     """Mooncake EP dispatch output."""

-    hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor]
-    topk_idx: torch.Tensor
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    topk_ids: torch.Tensor
     topk_weights: torch.Tensor
     masked_m: torch.Tensor
     expected_m: int
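
The NamedTuple change above splits the fused (tensor, scale) pair into two explicit fields and renames topk_idx to topk_ids. A hedged standalone sketch of how a consumer's access pattern changes (placeholder shapes, not part of this patch):

# Stand-in for the reshaped output; field names follow the diff.
from typing import NamedTuple
import torch

class MooncakeDispatchOutput(NamedTuple):
    hidden_states: torch.Tensor
    hidden_states_scale: torch.Tensor
    topk_ids: torch.Tensor
    topk_weights: torch.Tensor
    masked_m: torch.Tensor
    expected_m: int

out = MooncakeDispatchOutput(
    hidden_states=torch.zeros(4, 8),
    hidden_states_scale=torch.ones(4, 1),
    topk_ids=torch.zeros(4, 2, dtype=torch.int64),
    topk_weights=torch.full((4, 2), 0.5),
    masked_m=torch.tensor([4]),
    expected_m=4,
)
# Consumers now read two named fields instead of unpacking a tuple:
x, scale = out.hidden_states, out.hidden_states_scale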
@@ -164,23 +165,23 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
+        topk_output: TopKOutput,
     ):
+        topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
         buffer = self._get_buffer()
-        topk_idx = topk_idx.to(torch.int64)
+        topk_ids = topk_ids.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            hidden_states.shape[0] * buffer.group_size * topk_ids.shape[1]
             + self.num_experts
         ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             use_fp8=True,
         )
         return (
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
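
Across these hunks the routing result travels as one TopKOutput value instead of separate (topk_idx, topk_weights) tensors, with implementations unpacking at the boundary. A standalone sketch of the new calling convention (TopKOutput here is a stand-in carrying only the two fields this diff touches; not part of the patch):

from typing import NamedTuple
import torch

class TopKOutput(NamedTuple):
    topk_ids: torch.Tensor      # (num_tokens, top_k) expert indices
    topk_weights: torch.Tensor  # (num_tokens, top_k) routing weights

def dispatch_a(hidden_states: torch.Tensor, topk_output: TopKOutput):
    # Unpack at the boundary, as _MooncakeEPDispatcherImpl.dispatch_a now does.
    topk_ids, topk_weights = topk_output.topk_ids, topk_output.topk_weights
    return hidden_states, topk_ids.to(torch.int64), topk_weights

out = dispatch_a(
    torch.randn(4, 8),
    TopKOutput(torch.randint(0, 16, (4, 2)), torch.rand(4, 2)),
)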
@@ -191,7 +192,7 @@ class _MooncakeEPDispatcherImpl:
     def dispatch_b(
         self,
         hidden_states,
-        topk_idx,
+        topk_ids,
         topk_weights,
         masked_m,
         expected_m,
@@ -206,7 +207,7 @@ class _MooncakeEPDispatcherImpl:

         return MooncakeDispatchOutput(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             masked_m,
             expected_m,
@@ -215,14 +216,14 @@ class _MooncakeEPDispatcherImpl:
     def _dispatch_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         use_fp8: bool = False,
     ):
         buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
             buffer.dispatch(
                 hidden_states,
-                topk_idx,
+                topk_ids,
                 self.active_ranks,
                 self.num_max_dispatch_tokens_per_rank,
                 self.num_experts,
@@ -237,12 +238,12 @@ class _MooncakeEPDispatcherImpl:
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
         )
         return hidden_states, event, hook
@@ -254,13 +255,13 @@ class _MooncakeEPDispatcherImpl:
     def _combine_core(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
         buffer = self._get_buffer()
         combined_hidden_states, event, hook = buffer.combine(
             hidden_states,
-            topk_idx,
+            topk_ids,
             topk_weights,
             self.active_ranks,
             -1 if self.first_execution else self.timeout_us,
@@ -332,24 +333,20 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
-        input_global_scale: Optional[torch.Tensor],
-        topk_idx: torch.Tensor,
-        topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
+        topk_output: TopKOutput,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_batch).dispatch_a(
+        inner_state = self._get_impl().dispatch_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
+            topk_output=topk_output,
         )
-        self._dispatch_intermediate_state = forward_batch, inner_state
+        self._dispatch_intermediate_state = inner_state

     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_batch, inner_state = self._dispatch_intermediate_state
+        inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_batch).dispatch_b(*inner_state)
+        return self._get_impl().dispatch_b(*inner_state)

     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -359,29 +356,27 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def combine_a(
         self,
         hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
+        topk_ids: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_batch: ForwardBatch,
         overlap_args: Optional = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_batch).combine_a(
+        inner_state = self._get_impl().combine_a(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
+            topk_ids=topk_ids,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_batch, inner_state
+        self._combine_intermediate_state = inner_state

     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_batch, inner_state = self._combine_intermediate_state
+        inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_batch).combine_b(*inner_state)
+        return self._get_impl().combine_b(*inner_state)

-    def _get_impl(self, forward_batch: ForwardBatch) -> _MooncakeEPDispatcherImpl:
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
-        )
+    def _get_impl(self) -> _MooncakeEPDispatcherImpl:
+        is_extend_in_batch = get_is_extend_in_batch()
+        resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch)
         if resolved_deepep_mode == DeepEPMode.NORMAL:
             raise NotImplementedError
         elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
@@ -392,3 +387,6 @@ class MooncakeEPDispatcher(BaseDispatcher):
     def _update_stage(self, old_stage, new_stage):
         assert self._stage == old_stage
         self._stage = new_stage
+
+    def set_quant_config(self, quant_config: dict):
+        pass
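
These hunks add a set_quant_config hook to every dispatcher variant: quant methods push quantization parameters to whichever dispatcher the layer owns, and composite dispatchers fan the call out. A standalone sketch of the contract (plain-Python reduction of the hierarchy; class bodies are assumptions, not the sglang code):

class BaseDispatcher:
    def set_quant_config(self, quant_config: dict) -> None:
        pass  # default: this dispatcher needs no quant parameters

class LowLatencyDispatcher(BaseDispatcher):
    def set_quant_config(self, quant_config: dict) -> None:
        self._input_global_scale = quant_config.get("input_global_scale")

class CompositeDispatcher(BaseDispatcher):
    def __init__(self, inners):
        self._inners = list(inners)

    def set_quant_config(self, quant_config: dict) -> None:
        # Mirrors the fan-out in MaybeTboDeepEPDispatcher at the end of the diff.
        for inner in self._inners:
            inner.set_quant_config(quant_config)

d = CompositeDispatcher([LowLatencyDispatcher(), BaseDispatcher()])
d.set_quant_config({"input_global_scale": 1.0})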
@@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, NamedTuple

 import torch

+from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+)
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
 from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     CombineInput,
@@ -11,6 +16,8 @@ from sglang.srt.layers.moe.token_dispatcher.base import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
+from sglang.srt.layers.moe.utils import get_moe_runner_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -45,9 +52,45 @@ assert isinstance(StandardCombineInput, CombineInput)

 class StandardDispatcher(BaseDispatcher):

+    def __init__(self, moe_runner_config: MoeRunnerConfig):
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.enable_flashinfer_cutlass_moe = (
+            get_moe_runner_backend().is_flashinfer_cutlass()
+        )
+        self.num_experts = moe_runner_config.num_experts
+        self.num_local_experts = moe_runner_config.num_local_experts
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.local_expert_mapping = None
+
     def dispatch(
         self, hidden_states: torch.Tensor, topk_output: TopKOutput
     ) -> DispatchOutput:
+
+        if (
+            self.moe_ep_size > 1
+            and not self.enable_flashinfer_cutlass_moe
+            and TopKOutputChecker.format_is_standard(topk_output)
+        ):
+            if self.local_expert_mapping is None:
+                self.local_expert_mapping = torch.full(
+                    (self.num_experts,), -1, dtype=torch.int32, device="cuda"
+                )
+                self.local_expert_mapping[
+                    self.moe_ep_rank
+                    * self.num_local_experts : (self.moe_ep_rank + 1)
+                    * self.num_local_experts
+                ] = torch.arange(
+                    0, self.num_local_experts, dtype=torch.int32, device="cuda"
+                )
+
+        if self.local_expert_mapping is not None:
+            if TopKOutputChecker.format_is_standard(topk_output):
+                topk_output = topk_output._replace(
+                    topk_ids=self.local_expert_mapping[topk_output.topk_ids]
+                )
+            elif TopKOutputChecker.format_is_triton_kernel(topk_output):
+                raise NotImplementedError()
+
         return StandardDispatchOutput(
             hidden_states=hidden_states, topk_output=topk_output
         )
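
The mapping built above converts global expert ids to rank-local ids once and reuses it for every batch; experts not resident on the rank map to -1. A standalone sketch of the remap on CPU tensors with hypothetical sizes (8 experts, EP world size 2, rank 1 owning experts 4-7; illustration only):

import torch

num_experts, num_local_experts, ep_rank = 8, 4, 1

mapping = torch.full((num_experts,), -1, dtype=torch.int32)
mapping[ep_rank * num_local_experts : (ep_rank + 1) * num_local_experts] = (
    torch.arange(0, num_local_experts, dtype=torch.int32)
)

topk_ids = torch.tensor([[1, 5], [4, 7]])  # global expert ids from routing
print(mapping[topk_ids])
# tensor([[-1,  1],
#         [ 0,  3]], dtype=torch.int32)
# -1 marks experts owned by another rank, so the local MoE kernel skips them.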
@@ -59,3 +102,6 @@ class StandardDispatcher(BaseDispatcher):
             # TODO: this branch should be removed in the future
             assert isinstance(combine_input, torch.Tensor)
             return combine_input
+
+    def set_quant_config(self, quant_config: dict):
+        pass
@@ -365,9 +365,10 @@ class TopK(CustomOp):
     def empty_topk_output(self, device: torch.device) -> TopKOutput:
         topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
         topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
-        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        # FIXME: router_logits should be of size (0, num_experts)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
-        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+        return StandardTopKOutput(topk_weights, topk_ids, router_logits)


 # ------------------------------- TopK implementation -------------------------------------
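
empty_topk_output gives DP ranks with zero local tokens a shape-consistent routing result, which the model blocks later in this diff use instead of hand-rolling empty tensors. A hedged standalone sketch (the NamedTuple is a stand-in for StandardTopKOutput; not part of the patch):

from typing import NamedTuple
import torch

class StandardTopKOutput(NamedTuple):
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    router_logits: torch.Tensor

def empty_topk_output(top_k: int, device: torch.device) -> StandardTopKOutput:
    # Zero rows, but the same column layout as a real routing result.
    topk_weights = torch.empty((0, top_k), dtype=torch.float32, device=device)
    topk_ids = torch.full((0, top_k), -1, dtype=torch.int32, device=device)
    router_logits = torch.empty((0, top_k), dtype=torch.float32, device=device)
    return StandardTopKOutput(topk_weights, topk_ids, router_logits)

out = empty_topk_output(top_k=2, device=torch.device("cpu"))
assert out.topk_ids.shape == (0, 2)  # zero tokens, consistent layout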
@@ -1244,6 +1244,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             (1 / w2_input_scale).to(torch.float32), requires_grad=False
         )

+        layer.dispatcher.set_quant_config(
+            {"input_global_scale": layer.w13_input_scale_quant}
+        )
+
         # Validate weight scales
         for name, weight_scale in [
             ("w13", layer.w13_weight_scale),
@@ -339,7 +339,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         hidden_states, topk_idx, topk_weights = (
             dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
+            dispatch_output.topk_ids,
             dispatch_output.topk_weights,
         )
         if isinstance(hidden_states, tuple):
@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
@@ -639,6 +640,7 @@ class CudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         kwargs = {}
         if (
@@ -44,6 +44,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.utils import get_compiler_backend, is_npu, support_triton

@@ -688,6 +689,7 @@ class ForwardBatch:

         self.global_dp_buffer_len = buffer_len
         set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
+        set_is_extend_in_batch(self.is_extend_in_batch)

         bs = self.batch_size

@@ -38,6 +38,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
     set_dp_buffer_len,
+    set_is_extend_in_batch,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -377,6 +378,9 @@ class PiecewiseCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        # FIXME: the implementation is hacky. `is_extend_in_batch` is for determining the deepep mode.
+        # It is True in this context, but we need to set it to False to use the low-latency deepep mode.
+        set_is_extend_in_batch(False)

         kwargs = {}
         with set_forward_context(forward_batch, self.attention_layers):
@@ -380,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             if self.num_shared_experts > 0:
                 shared_output = self.shared_experts(hidden_states)

-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -389,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)

-        if self.ep_size > 1:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
-
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
-
-            final_hidden_states *= self.routed_scaling_factor

         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states += shared_output
         return final_hidden_states

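
After this refactor the model block no longer drives EP dispatch/combine itself; it produces one routing result and hands it to the experts layer, which hides the dispatcher behind that call. A condensed, hypothetical sketch of the resulting control flow (stand-in classes, not the sglang API):

import torch

class _TopK:
    def __call__(self, hidden_states: torch.Tensor):
        return torch.topk(torch.rand(hidden_states.shape[0], 8), k=2)

    def empty_topk_output(self, device: torch.device):
        # Shape-consistent result for ranks with zero local tokens.
        return torch.topk(torch.rand(0, 8, device=device), k=2)

class _Experts:
    def __call__(self, hidden_states: torch.Tensor, topk_output):
        # Stand-in for dispatch -> expert GEMMs -> combine, now internal.
        return hidden_states

def moe_block_forward(hidden_states, topk, experts, shared_output=None):
    if hidden_states.shape[0] > 0:
        topk_output = topk(hidden_states)
    else:
        topk_output = topk.empty_topk_output(hidden_states.device)
    out = experts(hidden_states=hidden_states, topk_output=topk_output)
    if shared_output is not None:
        out = out + shared_output
    return out

print(moe_block_forward(torch.randn(4, 16), _TopK(), _Experts()).shape)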
@@ -74,7 +74,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -112,10 +111,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -649,19 +645,6 @@ class DeepseekV2MoE(nn.Module):
             else None
         )

-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.n_routed_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=get_deepep_mode(),
-            async_finish=True,
-            return_recv_hook=True,
-        )
-
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
@@ -874,7 +857,7 @@ class DeepseekV2MoE(nn.Module):
             router_logits = self.gate(hidden_states)
             if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -883,9 +866,7 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)

         if self._fuse_shared_experts_inside_sbo:
             shared_output = None
@@ -896,9 +877,7 @@ class DeepseekV2MoE(nn.Module):

         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
             **(
                 dict(
                     forward_shared_experts=_forward_shared_experts_and_put_results,
@@ -960,7 +939,7 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+            state.topk_output = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -969,21 +948,13 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-                input_global_scale=None,
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

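
In the two-batch-overlap op pipeline the routing result now rides through the op state as a single topk_output entry that op_dispatch_a pops exactly once. A standalone sketch of that handoff (a plain dict stands in for the op state object; the dispatcher is a stub, not sglang's):

import torch

state = {}

def op_select_experts(hidden_states):
    # One value in, instead of separate topk_idx_local / topk_weights_local.
    state["topk_output"] = (torch.randint(0, 8, (4, 2)), torch.rand(4, 2))

def op_dispatch_a(hidden_states, dispatcher):
    # pop() guarantees the routing result cannot be reused by a later op.
    dispatcher(hidden_states=hidden_states, topk_output=state.pop("topk_output"))

op_select_experts(torch.randn(4, 16))
op_dispatch_a(torch.randn(4, 16), lambda **kw: None)
assert "topk_output" not in state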
@@ -992,32 +963,29 @@ class DeepseekV2MoE(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_output(self, state):
@@ -27,7 +27,6 @@ from sglang.srt.distributed import (
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
@@ -49,7 +48,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -71,7 +70,6 @@ from sglang.srt.models.deepseek_v2 import (
     DeepseekV2MoE,
 )
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -477,19 +475,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             else None
         )

-        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-            group=parallel_state.get_tp_group().device_group,
-            router_topk=self.top_k,
-            permute_fusion=True,
-            num_experts=self.num_experts,
-            num_local_experts=config.n_routed_experts // self.tp_size,
-            hidden_size=config.hidden_size,
-            params_dtype=config.torch_dtype,
-            deepep_mode=get_deepep_mode(),
-            async_finish=True,
-            return_recv_hook=True,
-        )
-
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
@@ -219,7 +219,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -228,14 +228,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )

         if shared_output is not None:
@@ -180,7 +180,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits, _ = self.gate(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -189,17 +189,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
         return final_hidden_states

@@ -219,7 +212,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+            state.topk_output = self.topk(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
                 num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -228,20 +221,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)

     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.pop("hidden_states_mlp_input"),
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

@@ -250,32 +236,29 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         with get_global_expert_distribution_recorder().with_current_layer(
             self.layer_id
         ):
-            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+            state.dispatch_output = self.experts.dispatcher.dispatch_b(
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")

     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_output(self, state):
@@ -1,3 +1,19 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Optional

@@ -5,12 +21,12 @@ import torch

 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.moe.utils import is_sbo_enabled
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import get_int_env_var

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE


 class SboFlags:
@@ -54,23 +70,22 @@ class DownGemmOverlapArgs:

 def execute_sbo(
     forward_shared_experts: Callable[[], Any],
-    experts: "DeepEPMoE",
+    experts: FusedMoE,
     hidden_states: torch.Tensor,
-    topk_idx: torch.Tensor,
-    topk_weights: torch.Tensor,
-    forward_batch: ForwardBatch,
-    alt_stream: Optional = None,
+    topk_output: TopKOutput,
+    alt_stream: Optional[torch.cuda.Stream] = None,
     disable_sbo: bool = False,
 ):
-    dispatch_output = experts.dispatch(
-        hidden_states, topk_idx, topk_weights, forward_batch
+    dispatch_output = experts.dispatcher.dispatch(
+        hidden_states=hidden_states, topk_output=topk_output
     )

     combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
         _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
     )

-    hidden_states = experts.moe_impl(
+    hidden_states = experts.run_moe_core(
         dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
     )
     if (e := meta_overlap_args.get("record_event_after_down")) is not None:
@@ -83,11 +98,10 @@ def execute_sbo(
     ):
         forward_shared_experts()

-    hidden_states = experts.combine(
-        hidden_states,
-        dispatch_output.topk_idx,
-        dispatch_output.topk_weights,
-        forward_batch,
+    hidden_states = experts.dispatcher.combine(
+        hidden_states=hidden_states,
+        topk_ids=dispatch_output.topk_ids,
+        topk_weights=dispatch_output.topk_weights,
         overlap_args=combine_overlap_args,
     )

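
execute_sbo now takes the routing result as a single topk_output and reaches dispatch/combine through experts.dispatcher, so the ForwardBatch parameter disappears. A hedged sketch of the resulting call shape (minimal stand-in objects for illustration, not the sglang classes):

class _Dispatcher:
    def dispatch(self, hidden_states, topk_output):
        return {"hidden_states": hidden_states, "topk_output": topk_output}

    def combine(self, hidden_states, **_):
        return hidden_states

class _Experts:
    dispatcher = _Dispatcher()

    def run_moe_core(self, dispatch_output):
        return dispatch_output["hidden_states"]

def execute_sbo(forward_shared_experts, experts, hidden_states, topk_output):
    out = experts.dispatcher.dispatch(hidden_states, topk_output)
    hidden = experts.run_moe_core(out)
    forward_shared_experts()  # overlapped with combine in the real code
    return experts.dispatcher.combine(hidden)

execute_sbo(lambda: None, _Experts(), [[0.0]], ([[1]], [[0.5]]))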
@@ -101,9 +115,7 @@ def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo):
     ):
         return None, None, {}

-    hidden_states = dispatch_output.hidden_states_fp8
-    if isinstance(hidden_states, tuple):
-        hidden_states = hidden_states[0]
+    hidden_states = dispatch_output.hidden_states

     num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape

@@ -14,6 +14,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -263,6 +264,7 @@ class EAGLEDraftCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
@@ -15,6 +15,7 @@ from sglang.srt.model_executor.cuda_graph_runner import (
     get_global_graph_memory_pool,
     model_capture_mode,
     set_global_graph_memory_pool,
+    set_is_extend_in_batch,
     set_torch_compile_config,
 )
 from sglang.srt.model_executor.forward_batch_info import (
@@ -294,6 +295,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Clean intermediate result cache for DP attention
         forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
         set_dp_buffer_len(global_dp_buffer_len, num_tokens)
+        set_is_extend_in_batch(False)

         # Backup two fields, which will be modified in-place in `draft_forward`.
         output_cache_loc_backup = forward_batch.out_cache_loc
@@ -1000,3 +1000,7 @@ class MaybeTboDeepEPDispatcher:

     def combine_b(self, **kwargs):
         return self._execute("combine_b", **kwargs)
+
+    def set_quant_config(self, quant_config: dict):
+        for inner in self._inners:
+            inner.set_quant_config(quant_config)