DP Attention with Auto DeepEP Dispatch (#7222)
This commit is contained in:
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.last_batch_in_queue = last_batch_in_queue
|
self.last_batch_in_queue = last_batch_in_queue
|
||||||
|
|
||||||
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
|
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
|
||||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
batch = self.prepare_mlp_sync_batch(batch)
|
||||||
result = None
|
result = None
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
batch = self.get_new_batch_prefill()
|
batch = self.get_new_batch_prefill()
|
||||||
|
|
||||||
if require_mlp_sync(self.server_args):
|
if require_mlp_sync(self.server_args):
|
||||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
batch = self.prepare_mlp_sync_batch(batch)
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
|
|
||||||
if batch:
|
if batch:
|
||||||
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
batch = self.get_new_batch_prefill()
|
batch = self.get_new_batch_prefill()
|
||||||
|
|
||||||
if require_mlp_sync(self.server_args):
|
if require_mlp_sync(self.server_args):
|
||||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
batch = self.prepare_mlp_sync_batch(batch)
|
||||||
self.cur_batch = batch
|
self.cur_batch = batch
|
||||||
if batch:
|
if batch:
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
DeepEPMode,
|
DeepEPMode,
|
||||||
ceil_div,
|
ceil_div,
|
||||||
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
|
|||||||
masked_m: torch.Tensor,
|
masked_m: torch.Tensor,
|
||||||
expected_m: int,
|
expected_m: int,
|
||||||
num_recv_tokens_per_expert: List[int],
|
num_recv_tokens_per_expert: List[int],
|
||||||
forward_mode: ForwardMode,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
|
||||||
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||||
|
forward_batch.is_extend_in_batch
|
||||||
|
)
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
return self.forward_deepgemm_contiguous(
|
return self.forward_deepgemm_contiguous(
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
deepep_run_moe_deep_preprocess,
|
deepep_run_moe_deep_preprocess,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
|
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||||
|
|
||||||
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode = None,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||||
inner_state = self._get_impl(forward_mode).dispatch_a(
|
inner_state = self._get_impl(forward_batch).dispatch_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
)
|
)
|
||||||
self._dispatch_intermediate_state = forward_mode, inner_state
|
self._dispatch_intermediate_state = forward_batch, inner_state
|
||||||
|
|
||||||
def dispatch_b(self):
|
def dispatch_b(self):
|
||||||
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
||||||
forward_mode, inner_state = self._dispatch_intermediate_state
|
forward_batch, inner_state = self._dispatch_intermediate_state
|
||||||
del self._dispatch_intermediate_state
|
del self._dispatch_intermediate_state
|
||||||
return self._get_impl(forward_mode).dispatch_b(*inner_state)
|
return self._get_impl(forward_batch).dispatch_b(*inner_state)
|
||||||
|
|
||||||
def combine(self, *args, **kwargs) -> Tuple:
|
def combine(self, *args, **kwargs) -> Tuple:
|
||||||
self.combine_a(*args, **kwargs)
|
self.combine_a(*args, **kwargs)
|
||||||
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
forward_batch: ForwardBatch,
|
||||||
):
|
):
|
||||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||||
inner_state = self._get_impl(forward_mode).combine_a(
|
inner_state = self._get_impl(forward_batch).combine_a(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
)
|
)
|
||||||
self._combine_intermediate_state = forward_mode, inner_state
|
self._combine_intermediate_state = forward_batch, inner_state
|
||||||
|
|
||||||
def combine_b(self):
|
def combine_b(self):
|
||||||
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
||||||
forward_mode, inner_state = self._combine_intermediate_state
|
forward_batch, inner_state = self._combine_intermediate_state
|
||||||
del self._combine_intermediate_state
|
del self._combine_intermediate_state
|
||||||
return self._get_impl(forward_mode).combine_b(*inner_state)
|
return self._get_impl(forward_batch).combine_b(*inner_state)
|
||||||
|
|
||||||
def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
|
def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||||
|
forward_batch.is_extend_in_batch
|
||||||
|
)
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
return self._normal_dispatcher
|
return self._normal_dispatcher
|
||||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||||
|
|||||||
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]] = None
|
global_num_tokens: Optional[List[int]] = None
|
||||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||||
|
is_extend_in_batch: bool = False
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
is_extend_in_batch: bool = False
|
is_extend_in_batch: bool = False
|
||||||
tbo_split_seq_index: Optional[int] = None
|
tbo_split_seq_index: Optional[int] = None
|
||||||
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
token_ids_logprobs=self.token_ids_logprobs,
|
token_ids_logprobs=self.token_ids_logprobs,
|
||||||
global_num_tokens=self.global_num_tokens,
|
global_num_tokens=self.global_num_tokens,
|
||||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||||
|
is_extend_in_batch=self.is_extend_in_batch,
|
||||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||||
tbo_split_seq_index=self.tbo_split_seq_index,
|
tbo_split_seq_index=self.tbo_split_seq_index,
|
||||||
global_forward_mode=self.global_forward_mode,
|
global_forward_mode=self.global_forward_mode,
|
||||||
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
|
|||||||
# For DP attention
|
# For DP attention
|
||||||
global_num_tokens: Optional[List[int]]
|
global_num_tokens: Optional[List[int]]
|
||||||
global_num_tokens_for_logprob: Optional[List[int]]
|
global_num_tokens_for_logprob: Optional[List[int]]
|
||||||
|
is_extend_in_batch: bool
|
||||||
can_run_dp_cuda_graph: bool
|
can_run_dp_cuda_graph: bool
|
||||||
tbo_split_seq_index: Optional[int]
|
tbo_split_seq_index: Optional[int]
|
||||||
global_forward_mode: Optional[ForwardMode]
|
global_forward_mode: Optional[ForwardMode]
|
||||||
|
|||||||
@@ -1490,7 +1490,7 @@ class Scheduler(
|
|||||||
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
|
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
|
||||||
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
|
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
|
||||||
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
|
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
|
||||||
new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
|
new_batch = self.prepare_mlp_sync_batch(new_batch)
|
||||||
need_dp_attn_preparation = new_batch is None
|
need_dp_attn_preparation = new_batch is None
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
@@ -1506,7 +1506,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Handle DP attention
|
# Handle DP attention
|
||||||
if need_dp_attn_preparation:
|
if need_dp_attn_preparation:
|
||||||
ret, _ = self.prepare_mlp_sync_batch(ret)
|
ret = self.prepare_mlp_sync_batch(ret)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -1923,8 +1923,7 @@ class Scheduler(
|
|||||||
if not disable_cuda_graph:
|
if not disable_cuda_graph:
|
||||||
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
||||||
|
|
||||||
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
|
return local_batch
|
||||||
return local_batch, any(is_extend_in_batch)
|
|
||||||
|
|
||||||
def get_idle_batch(self):
|
def get_idle_batch(self):
|
||||||
idle_batch = ScheduleBatch.init_new(
|
idle_batch = ScheduleBatch.init_new(
|
||||||
|
|||||||
@@ -254,6 +254,7 @@ class ForwardBatch:
|
|||||||
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
||||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||||
gathered_buffer: Optional[torch.Tensor] = None
|
gathered_buffer: Optional[torch.Tensor] = None
|
||||||
|
is_extend_in_batch: bool = False
|
||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
global_forward_mode: Optional[ForwardMode] = None
|
global_forward_mode: Optional[ForwardMode] = None
|
||||||
|
|
||||||
@@ -299,6 +300,7 @@ class ForwardBatch:
|
|||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
token_ids_logprobs=batch.token_ids_logprobs,
|
token_ids_logprobs=batch.token_ids_logprobs,
|
||||||
|
is_extend_in_batch=batch.is_extend_in_batch,
|
||||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||||
global_forward_mode=batch.global_forward_mode,
|
global_forward_mode=batch.global_forward_mode,
|
||||||
lora_paths=batch.lora_paths,
|
lora_paths=batch.lora_paths,
|
||||||
|
|||||||
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
masked_m=masked_m,
|
masked_m=masked_m,
|
||||||
expected_m=expected_m,
|
expected_m=expected_m,
|
||||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
hidden_states=final_hidden_states,
|
hidden_states=final_hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states=state.hidden_states_mlp_input,
|
hidden_states=state.hidden_states_mlp_input,
|
||||||
topk_idx=state.pop("topk_idx_local"),
|
topk_idx=state.pop("topk_idx_local"),
|
||||||
topk_weights=state.pop("topk_weights_local"),
|
topk_weights=state.pop("topk_weights_local"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
masked_m=state.pop("masked_m"),
|
masked_m=state.pop("masked_m"),
|
||||||
expected_m=state.pop("expected_m"),
|
expected_m=state.pop("expected_m"),
|
||||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_combine_a(self, state):
|
def op_combine_a(self, state):
|
||||||
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states=state.pop("hidden_states_experts_output"),
|
hidden_states=state.pop("hidden_states_experts_output"),
|
||||||
topk_idx=state.pop("topk_idx_dispatched"),
|
topk_idx=state.pop("topk_idx_dispatched"),
|
||||||
topk_weights=state.pop("topk_weights_dispatched"),
|
topk_weights=state.pop("topk_weights_dispatched"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
and hidden_states.shape[0] == 0
|
and hidden_states.shape[0] == 0
|
||||||
):
|
):
|
||||||
state.hidden_states_mlp_output = self.mlp(
|
state.hidden_states_mlp_output = self.mlp(
|
||||||
hidden_states, state.forward_batch.forward_mode
|
hidden_states, state.forward_batch
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
state.hidden_states_mlp_output = hidden_states
|
state.hidden_states_mlp_output = hidden_states
|
||||||
|
|||||||
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
final_hidden_states = self.experts(
|
final_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
masked_m=masked_m,
|
masked_m=masked_m,
|
||||||
expected_m=expected_m,
|
expected_m=expected_m,
|
||||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
if self.ep_size > 1:
|
if self.ep_size > 1:
|
||||||
final_hidden_states = self.deepep_dispatcher.combine(
|
final_hidden_states = self.deepep_dispatcher.combine(
|
||||||
hidden_states=final_hidden_states,
|
hidden_states=final_hidden_states,
|
||||||
topk_idx=topk_idx,
|
topk_idx=topk_idx,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
forward_mode=forward_mode,
|
forward_batch=forward_batch,
|
||||||
)
|
)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||||
topk_idx=state.pop("topk_idx_local"),
|
topk_idx=state.pop("topk_idx_local"),
|
||||||
topk_weights=state.pop("topk_weights_local"),
|
topk_weights=state.pop("topk_weights_local"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
masked_m=state.pop("masked_m"),
|
masked_m=state.pop("masked_m"),
|
||||||
expected_m=state.pop("expected_m"),
|
expected_m=state.pop("expected_m"),
|
||||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
def op_combine_a(self, state):
|
def op_combine_a(self, state):
|
||||||
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
|||||||
hidden_states=state.pop("hidden_states_experts_output"),
|
hidden_states=state.pop("hidden_states_experts_output"),
|
||||||
topk_idx=state.pop("topk_idx_dispatched"),
|
topk_idx=state.pop("topk_idx_dispatched"),
|
||||||
topk_weights=state.pop("topk_weights_dispatched"),
|
topk_weights=state.pop("topk_weights_dispatched"),
|
||||||
forward_mode=state.forward_batch.forward_mode,
|
forward_batch=state.forward_batch,
|
||||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def op_mlp(self, state):
|
def op_mlp(self, state):
|
||||||
hidden_states = state.pop("hidden_states_mlp_input")
|
hidden_states = state.pop("hidden_states_mlp_input")
|
||||||
state.hidden_states_mlp_output = self.mlp(
|
state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
|
||||||
hidden_states, state.forward_batch.forward_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
def op_comm_postprocess_layer(self, state):
|
def op_comm_postprocess_layer(self, state):
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
|
|||||||
@@ -418,10 +418,6 @@ class ServerArgs:
|
|||||||
|
|
||||||
# DeepEP MoE
|
# DeepEP MoE
|
||||||
if self.enable_deepep_moe:
|
if self.enable_deepep_moe:
|
||||||
if self.deepep_mode == "auto":
|
|
||||||
assert (
|
|
||||||
not self.enable_dp_attention
|
|
||||||
), "DeepEP MoE `auto` mode is not supported with DP Attention."
|
|
||||||
if self.deepep_mode == "normal":
|
if self.deepep_mode == "normal":
|
||||||
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
|
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
|
||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||||
from sglang.srt.operations_strategy import OperationsStrategy
|
from sglang.srt.operations_strategy import OperationsStrategy
|
||||||
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:
|
|||||||
|
|
||||||
class TboDPAttentionPreparer:
|
class TboDPAttentionPreparer:
|
||||||
def prepare_all_gather(
|
def prepare_all_gather(
|
||||||
self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
|
self,
|
||||||
|
local_batch: ScheduleBatch,
|
||||||
|
deepep_mode: DeepEPMode,
|
||||||
|
enable_deepep_moe: bool,
|
||||||
|
enable_two_batch_overlap: bool,
|
||||||
):
|
):
|
||||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||||
|
|
||||||
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
|
|||||||
extend_lens=local_batch.extend_lens,
|
extend_lens=local_batch.extend_lens,
|
||||||
token_num_per_seq=token_num_per_seq,
|
token_num_per_seq=token_num_per_seq,
|
||||||
)
|
)
|
||||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
|
||||||
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
||||||
(
|
(
|
||||||
local_batch.forward_mode.is_extend()
|
local_batch.forward_mode.is_extend()
|
||||||
|
|||||||
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
|
|||||||
def enable_low_latency(self):
|
def enable_low_latency(self):
|
||||||
return self in [DeepEPMode.low_latency, DeepEPMode.auto]
|
return self in [DeepEPMode.low_latency, DeepEPMode.auto]
|
||||||
|
|
||||||
def resolve(self, forward_mode):
|
def resolve(self, is_extend_in_batch: bool):
|
||||||
if self != DeepEPMode.auto:
|
if self != DeepEPMode.auto:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if forward_mode.is_decode():
|
if is_extend_in_batch:
|
||||||
return DeepEPMode.low_latency
|
|
||||||
else:
|
|
||||||
return DeepEPMode.normal
|
return DeepEPMode.normal
|
||||||
|
else:
|
||||||
|
return DeepEPMode.low_latency
|
||||||
|
|
||||||
|
|
||||||
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
||||||
|
|||||||
@@ -539,8 +539,9 @@ class Test10(CustomTestCase):
|
|||||||
"8",
|
"8",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -593,8 +594,9 @@ class Test11(CustomTestCase):
|
|||||||
"4",
|
"4",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -647,8 +649,9 @@ class Test12(CustomTestCase):
|
|||||||
"8",
|
"8",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -700,8 +703,9 @@ class Test13(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -756,8 +760,9 @@ class Test14(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -812,8 +817,9 @@ class Test15(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -867,8 +873,9 @@ class Test16(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -922,8 +929,9 @@ class Test17(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -979,8 +987,9 @@ class Test18(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1036,8 +1045,9 @@ class Test19(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"128",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2213,8 +2223,11 @@ class Test40(CustomTestCase):
|
|||||||
"8",
|
"8",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2277,8 +2290,11 @@ class Test41(CustomTestCase):
|
|||||||
"4",
|
"4",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2341,8 +2357,11 @@ class Test42(CustomTestCase):
|
|||||||
"8",
|
"8",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2404,8 +2423,11 @@ class Test43(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2470,8 +2492,11 @@ class Test44(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2536,8 +2561,11 @@ class Test45(CustomTestCase):
|
|||||||
"1",
|
"1",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2601,8 +2629,11 @@ class Test46(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2666,8 +2697,11 @@ class Test47(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2733,8 +2767,11 @@ class Test48(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
@@ -2800,8 +2837,11 @@ class Test49(CustomTestCase):
|
|||||||
"--enable-dp-lm-head",
|
"--enable-dp-lm-head",
|
||||||
"--enable-deepep-moe",
|
"--enable-deepep-moe",
|
||||||
"--deepep-mode",
|
"--deepep-mode",
|
||||||
"normal",
|
"auto",
|
||||||
"--disable-cuda-graph",
|
"--cuda-graph-max-bs",
|
||||||
|
"32",
|
||||||
|
"--max-running-requests",
|
||||||
|
"32",
|
||||||
"--speculative-algo",
|
"--speculative-algo",
|
||||||
"NEXTN",
|
"NEXTN",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
|
|||||||
Reference in New Issue
Block a user