DP Attention with Auto DeepEP Dispatch (#7222)
This commit is contained in:
@@ -772,7 +772,7 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.last_batch_in_queue = last_batch_in_queue
|
||||
|
||||
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
batch = self.prepare_mlp_sync_batch(batch)
|
||||
result = None
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
|
||||
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
batch = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
batch = self.get_new_batch_prefill()
|
||||
|
||||
if require_mlp_sync(self.server_args):
|
||||
batch, _ = self.prepare_mlp_sync_batch(batch)
|
||||
batch = self.prepare_mlp_sync_batch(batch)
|
||||
self.cur_batch = batch
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
|
||||
@@ -42,7 +42,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import (
|
||||
DeepEPMode,
|
||||
ceil_div,
|
||||
@@ -1178,12 +1178,14 @@ class DeepEPMoE(EPMoE):
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
num_recv_tokens_per_expert: List[int],
|
||||
forward_mode: ForwardMode,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
if _use_aiter:
|
||||
# In forward_aiter, we skip token permutation and unpermutation, which are already fused inside the aiter kernel.
|
||||
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||
forward_batch.is_extend_in_batch
|
||||
)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||
return self.forward_deepgemm_contiguous(
|
||||
|
||||
@@ -34,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
deepep_run_moe_deep_preprocess,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||
|
||||
@@ -686,21 +686,21 @@ class DeepEPDispatcher:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode = None,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
|
||||
inner_state = self._get_impl(forward_mode).dispatch_a(
|
||||
inner_state = self._get_impl(forward_batch).dispatch_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
)
|
||||
self._dispatch_intermediate_state = forward_mode, inner_state
|
||||
self._dispatch_intermediate_state = forward_batch, inner_state
|
||||
|
||||
def dispatch_b(self):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
|
||||
forward_mode, inner_state = self._dispatch_intermediate_state
|
||||
forward_batch, inner_state = self._dispatch_intermediate_state
|
||||
del self._dispatch_intermediate_state
|
||||
return self._get_impl(forward_mode).dispatch_b(*inner_state)
|
||||
return self._get_impl(forward_batch).dispatch_b(*inner_state)
|
||||
|
||||
def combine(self, *args, **kwargs) -> Tuple:
|
||||
self.combine_a(*args, **kwargs)
|
||||
@@ -712,24 +712,26 @@ class DeepEPDispatcher:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_mode: ForwardMode,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_mode).combine_a(
|
||||
inner_state = self._get_impl(forward_batch).combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
)
|
||||
self._combine_intermediate_state = forward_mode, inner_state
|
||||
self._combine_intermediate_state = forward_batch, inner_state
|
||||
|
||||
def combine_b(self):
|
||||
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
|
||||
forward_mode, inner_state = self._combine_intermediate_state
|
||||
forward_batch, inner_state = self._combine_intermediate_state
|
||||
del self._combine_intermediate_state
|
||||
return self._get_impl(forward_mode).combine_b(*inner_state)
|
||||
return self._get_impl(forward_batch).combine_b(*inner_state)
|
||||
|
||||
def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||
def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
|
||||
resolved_deepep_mode = self.deepep_mode.resolve(
|
||||
forward_batch.is_extend_in_batch
|
||||
)
|
||||
if resolved_deepep_mode == DeepEPMode.normal:
|
||||
return self._normal_dispatcher
|
||||
elif resolved_deepep_mode == DeepEPMode.low_latency:
|
||||
|
||||
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
is_extend_in_batch: bool = False
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
is_extend_in_batch: bool = False
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
token_ids_logprobs=self.token_ids_logprobs,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
is_extend_in_batch=self.is_extend_in_batch,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
tbo_split_seq_index=self.tbo_split_seq_index,
|
||||
global_forward_mode=self.global_forward_mode,
|
||||
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
is_extend_in_batch: bool
|
||||
can_run_dp_cuda_graph: bool
|
||||
tbo_split_seq_index: Optional[int]
|
||||
global_forward_mode: Optional[ForwardMode]
|
||||
|
||||
@@ -1490,7 +1490,7 @@ class Scheduler(
|
||||
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
|
||||
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
|
||||
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
|
||||
new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
|
||||
new_batch = self.prepare_mlp_sync_batch(new_batch)
|
||||
need_dp_attn_preparation = new_batch is None
|
||||
|
||||
if new_batch is not None:
|
||||
@@ -1506,7 +1506,7 @@ class Scheduler(
|
||||
|
||||
# Handle DP attention
|
||||
if need_dp_attn_preparation:
|
||||
ret, _ = self.prepare_mlp_sync_batch(ret)
|
||||
ret = self.prepare_mlp_sync_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -1923,8 +1923,7 @@ class Scheduler(
|
||||
if not disable_cuda_graph:
|
||||
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
||||
|
||||
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
|
||||
return local_batch, any(is_extend_in_batch)
|
||||
return local_batch
|
||||
|
||||
def get_idle_batch(self):
|
||||
idle_batch = ScheduleBatch.init_new(
|
||||
|
||||
@@ -254,6 +254,7 @@ class ForwardBatch:
|
||||
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
||||
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
||||
gathered_buffer: Optional[torch.Tensor] = None
|
||||
is_extend_in_batch: bool = False
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
@@ -299,6 +300,7 @@ class ForwardBatch:
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
token_ids_logprobs=batch.token_ids_logprobs,
|
||||
is_extend_in_batch=batch.is_extend_in_batch,
|
||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||
global_forward_mode=batch.global_forward_mode,
|
||||
lora_paths=batch.lora_paths,
|
||||
|
||||
@@ -558,7 +558,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -569,14 +569,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
hidden_states=final_hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if shared_output is not None:
|
||||
@@ -651,7 +651,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states=state.hidden_states_mlp_input,
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -683,7 +683,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
@@ -692,7 +692,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -1881,7 +1881,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
and hidden_states.shape[0] == 0
|
||||
):
|
||||
state.hidden_states_mlp_output = self.mlp(
|
||||
hidden_states, state.forward_batch.forward_mode
|
||||
hidden_states, state.forward_batch
|
||||
)
|
||||
else:
|
||||
state.hidden_states_mlp_output = hidden_states
|
||||
|
||||
@@ -229,7 +229,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -240,14 +240,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
hidden_states=final_hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -293,7 +293,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
@@ -334,7 +334,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
@@ -647,9 +647,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
def op_mlp(self, state):
|
||||
hidden_states = state.pop("hidden_states_mlp_input")
|
||||
state.hidden_states_mlp_output = self.mlp(
|
||||
hidden_states, state.forward_batch.forward_mode
|
||||
)
|
||||
state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
|
||||
|
||||
def op_comm_postprocess_layer(self, state):
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
|
||||
@@ -418,10 +418,6 @@ class ServerArgs:
|
||||
|
||||
# DeepEP MoE
|
||||
if self.enable_deepep_moe:
|
||||
if self.deepep_mode == "auto":
|
||||
assert (
|
||||
not self.enable_dp_attention
|
||||
), "DeepEP MoE `auto` mode is not supported with DP Attention."
|
||||
if self.deepep_mode == "normal":
|
||||
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
|
||||
self.disable_cuda_graph = True
|
||||
|
||||
@@ -13,7 +13,7 @@ from sglang.srt.layers.communicator import (
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
@@ -272,7 +272,11 @@ class TboCudaGraphRunnerPlugin:
|
||||
|
||||
class TboDPAttentionPreparer:
|
||||
def prepare_all_gather(
|
||||
self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
|
||||
self,
|
||||
local_batch: ScheduleBatch,
|
||||
deepep_mode: DeepEPMode,
|
||||
enable_deepep_moe: bool,
|
||||
enable_two_batch_overlap: bool,
|
||||
):
|
||||
self.enable_two_batch_overlap = enable_two_batch_overlap
|
||||
|
||||
@@ -294,7 +298,7 @@ class TboDPAttentionPreparer:
|
||||
extend_lens=local_batch.extend_lens,
|
||||
token_num_per_seq=token_num_per_seq,
|
||||
)
|
||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
|
||||
resolved_deepep_mode = deepep_mode.resolve(local_batch.is_extend_in_batch)
|
||||
local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
|
||||
(
|
||||
local_batch.forward_mode.is_extend()
|
||||
|
||||
@@ -2202,14 +2202,14 @@ class DeepEPMode(Enum):
|
||||
def enable_low_latency(self):
|
||||
return self in [DeepEPMode.low_latency, DeepEPMode.auto]
|
||||
|
||||
def resolve(self, forward_mode):
|
||||
def resolve(self, is_extend_in_batch: bool):
|
||||
if self != DeepEPMode.auto:
|
||||
return self
|
||||
|
||||
if forward_mode.is_decode():
|
||||
return DeepEPMode.low_latency
|
||||
else:
|
||||
if is_extend_in_batch:
|
||||
return DeepEPMode.normal
|
||||
else:
|
||||
return DeepEPMode.low_latency
|
||||
|
||||
|
||||
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
||||
|
||||
Reference in New Issue
Block a user