qwen3moe support two batch overlap (#6598)
This commit is contained in:
@@ -68,6 +68,7 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -442,12 +443,22 @@ class Qwen2MoeModel(nn.Module):
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
if forward_batch.can_run_tbo:
|
||||
hidden_states, residual = model_forward_maybe_tbo(
|
||||
layers=self.layers,
|
||||
enable_tbo=True,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
)
|
||||
else:
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
if not self.pp_group.is_last_rank:
|
||||
return PPProxyTensors(
|
||||
{
|
||||
|
||||
@@ -68,6 +68,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.expert_distribution import (
|
||||
get_global_expert_distribution_recorder,
|
||||
)
|
||||
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -79,6 +82,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
|
||||
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
|
||||
|
||||
Qwen3MoeConfig = None
|
||||
@@ -137,7 +141,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.renormalize = config.norm_topk_prob
|
||||
|
||||
self.deepep_dispatcher = DeepEPDispatcher(
|
||||
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
||||
group=parallel_state.get_tp_group().device_group,
|
||||
router_topk=self.top_k,
|
||||
permute_fusion=True,
|
||||
@@ -217,9 +221,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
masked_m,
|
||||
expected_m,
|
||||
) = self.deepep_dispatcher.dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
final_hidden_states = self.experts(
|
||||
@@ -235,13 +239,105 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
)
|
||||
if self.ep_size > 1:
|
||||
final_hidden_states = self.deepep_dispatcher.combine(
|
||||
final_hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
forward_mode,
|
||||
hidden_states=final_hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_mode=forward_mode,
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
def op_gate(self, state):
|
||||
if is_non_idle_and_non_empty(
|
||||
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
state.router_logits, _ = self.gate(state.hidden_states_mlp_input)
|
||||
else:
|
||||
state.router_logits = None
|
||||
|
||||
def op_select_experts(self, state):
|
||||
router_logits = state.pop("router_logits")
|
||||
hidden_states = state.hidden_states_mlp_input
|
||||
if router_logits is not None:
|
||||
state.topk_weights_local, state.topk_idx_local = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=self.renormalize,
|
||||
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
||||
layer_id=self.layer_id,
|
||||
),
|
||||
)
|
||||
else:
|
||||
state.topk_idx_local = torch.full(
|
||||
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
||||
)
|
||||
state.topk_weights_local = torch.empty(
|
||||
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
|
||||
def op_dispatch_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
||||
self.deepep_dispatcher.dispatch_a(
|
||||
hidden_states=state.pop("hidden_states_mlp_input"),
|
||||
topk_idx=state.pop("topk_idx_local"),
|
||||
topk_weights=state.pop("topk_weights_local"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_dispatch_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
with get_global_expert_distribution_recorder().with_current_layer(
|
||||
self.layer_id
|
||||
):
|
||||
(
|
||||
state.hidden_states_experts_input,
|
||||
state.topk_idx_dispatched,
|
||||
state.topk_weights_dispatched,
|
||||
state.reorder_topk_ids,
|
||||
state.num_recv_tokens_per_expert,
|
||||
state.seg_indptr,
|
||||
state.masked_m,
|
||||
state.expected_m,
|
||||
) = self.deepep_dispatcher.dispatch_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_experts(self, state):
|
||||
state.hidden_states_experts_output = self.experts(
|
||||
hidden_states=state.pop("hidden_states_experts_input"),
|
||||
topk_idx=state.topk_idx_dispatched,
|
||||
topk_weights=state.topk_weights_dispatched,
|
||||
reorder_topk_ids=state.pop("reorder_topk_ids"),
|
||||
seg_indptr=state.pop("seg_indptr"),
|
||||
masked_m=state.pop("masked_m"),
|
||||
expected_m=state.pop("expected_m"),
|
||||
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
)
|
||||
|
||||
def op_combine_a(self, state):
|
||||
if self.ep_size > 1:
|
||||
self.deepep_dispatcher.combine_a(
|
||||
hidden_states=state.pop("hidden_states_experts_output"),
|
||||
topk_idx=state.pop("topk_idx_dispatched"),
|
||||
topk_weights=state.pop("topk_weights_dispatched"),
|
||||
forward_mode=state.forward_batch.forward_mode,
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_combine_b(self, state):
|
||||
if self.ep_size > 1:
|
||||
state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
|
||||
tbo_subbatch_index=state.get("tbo_subbatch_index"),
|
||||
)
|
||||
|
||||
def op_output(self, state):
|
||||
state.hidden_states_mlp_output = state.pop("hidden_states_after_combine")
|
||||
|
||||
|
||||
class Qwen3MoeAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -339,19 +435,53 @@ class Qwen3MoeAttention(nn.Module):
|
||||
k = k_by_head.view(k.shape)
|
||||
return q, k
|
||||
|
||||
def op_prepare(self, state):
|
||||
state.attn_intermediate_state = self.forward_prepare(
|
||||
positions=state.positions,
|
||||
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
|
||||
forward_batch=state.forward_batch,
|
||||
)
|
||||
|
||||
def op_core(self, state):
|
||||
state.hidden_states_after_attn = self.forward_core(
|
||||
state.pop("attn_intermediate_state")
|
||||
)
|
||||
|
||||
def forward_prepare(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
if hidden_states.shape[0] == 0:
|
||||
return hidden_states, forward_batch, None
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
inner_state = q, k, v, forward_batch
|
||||
return None, forward_batch, inner_state
|
||||
|
||||
def forward_core(self, intermediate_state):
|
||||
hidden_states, forward_batch, inner_state = intermediate_state
|
||||
if inner_state is None:
|
||||
return hidden_states
|
||||
attn_output = self.attn(*inner_state)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self._apply_qk_norm(q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
s = self.forward_prepare(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
return self.forward_core(s)
|
||||
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
|
||||
@@ -462,6 +592,65 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def op_comm_prepare_attn(
|
||||
self,
|
||||
state,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
tbo_subbatch_index: Optional[int] = None,
|
||||
):
|
||||
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
|
||||
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
|
||||
)
|
||||
state.update(
|
||||
dict(
|
||||
forward_batch=forward_batch,
|
||||
positions=positions,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
)
|
||||
)
|
||||
|
||||
def op_comm_prepare_mlp(self, state):
|
||||
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
|
||||
self.layer_communicator.prepare_mlp(
|
||||
state.pop("hidden_states_after_attn"),
|
||||
state.pop("residual_after_input_ln"),
|
||||
state.forward_batch,
|
||||
)
|
||||
)
|
||||
|
||||
def op_mlp(self, state):
|
||||
hidden_states = state.pop("hidden_states_mlp_input")
|
||||
state.hidden_states_mlp_output = self.mlp(
|
||||
hidden_states, state.forward_batch.forward_mode
|
||||
)
|
||||
|
||||
def op_comm_postprocess_layer(self, state):
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
state.pop("hidden_states_mlp_output"),
|
||||
state.pop("residual_after_comm_pre_mlp"),
|
||||
state.forward_batch,
|
||||
)
|
||||
|
||||
output = dict(
|
||||
positions=state.positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=state.forward_batch,
|
||||
tbo_subbatch_index=state.tbo_subbatch_index,
|
||||
)
|
||||
|
||||
state.clear(
|
||||
expect_keys={
|
||||
"positions",
|
||||
"forward_batch",
|
||||
"tbo_subbatch_index",
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen3MoeModel(Qwen2MoeModel):
|
||||
def __init__(
|
||||
|
||||
@@ -32,12 +32,27 @@ class OperationsStrategy:
|
||||
layers: torch.nn.ModuleList,
|
||||
forward_mode: ForwardMode,
|
||||
) -> "OperationsStrategy":
|
||||
return OperationsStrategy.concat(
|
||||
[
|
||||
_compute_layer_operations_strategy_tbo(layer, forward_mode)
|
||||
for layer in layers
|
||||
]
|
||||
)
|
||||
layer_name = layers[0].__class__.__name__
|
||||
if layer_name == "DeepseekV2DecoderLayer":
|
||||
return OperationsStrategy.concat(
|
||||
[
|
||||
_compute_moe_deepseek_layer_operations_strategy_tbo(
|
||||
layer, forward_mode
|
||||
)
|
||||
for layer in layers
|
||||
]
|
||||
)
|
||||
elif layer_name == "Qwen3MoeDecoderLayer":
|
||||
return OperationsStrategy.concat(
|
||||
[
|
||||
_compute_moe_qwen3_layer_operations_strategy_tbo(
|
||||
layer, forward_mode
|
||||
)
|
||||
for layer in layers
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _assert_all_same(items: List):
|
||||
@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
|
||||
return items[0]
|
||||
|
||||
|
||||
# -------------------------------- Strategy for DeepSeek ---------------------------------------
|
||||
|
||||
|
||||
# TODO can refactor to make it more fancy if we have more complex strategies
|
||||
def _compute_layer_operations_strategy_tbo(
|
||||
def _compute_moe_deepseek_layer_operations_strategy_tbo(
|
||||
layer: torch.nn.Module,
|
||||
forward_mode: ForwardMode,
|
||||
) -> OperationsStrategy:
|
||||
@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
|
||||
operations.YieldOperation(),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------- Strategy for Qwen3 ---------------------------------------
|
||||
|
||||
|
||||
# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
|
||||
# convenience to adjust strategy
|
||||
def _compute_moe_qwen3_layer_operations_strategy_tbo(
|
||||
layer: torch.nn.Module,
|
||||
forward_mode: ForwardMode,
|
||||
) -> OperationsStrategy:
|
||||
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
|
||||
if forward_mode == ForwardMode.EXTEND:
|
||||
return _compute_moe_qwen3_prefill(layer)
|
||||
elif forward_mode == ForwardMode.DECODE:
|
||||
return _compute_moe_qwen3_decode(layer)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported {forward_mode=}")
|
||||
|
||||
|
||||
def _compute_moe_qwen3_prefill(layer):
|
||||
device_properties = torch.cuda.get_device_properties(device="cuda")
|
||||
total_num_sms = device_properties.multi_processor_count
|
||||
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
|
||||
|
||||
return OperationsStrategy(
|
||||
deep_gemm_num_sms=deep_gemm_num_sms,
|
||||
tbo_delta_stages=0,
|
||||
operations=[
|
||||
layer.op_comm_prepare_attn,
|
||||
layer.self_attn.op_prepare,
|
||||
layer.self_attn.op_core,
|
||||
layer.op_comm_prepare_mlp,
|
||||
layer.mlp.op_gate,
|
||||
layer.mlp.op_select_experts,
|
||||
layer.mlp.op_dispatch_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_b,
|
||||
layer.mlp.op_experts,
|
||||
layer.mlp.op_combine_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_combine_b,
|
||||
layer.mlp.op_output,
|
||||
layer.op_comm_postprocess_layer,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _compute_moe_qwen3_decode(layer):
|
||||
return OperationsStrategy(
|
||||
deep_gemm_num_sms=None,
|
||||
tbo_delta_stages=2,
|
||||
operations=[
|
||||
layer.op_comm_prepare_attn,
|
||||
layer.self_attn.op_prepare,
|
||||
operations.YieldOperation(),
|
||||
layer.self_attn.op_core,
|
||||
layer.op_comm_prepare_mlp,
|
||||
layer.mlp.op_gate,
|
||||
layer.mlp.op_select_experts,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_dispatch_b,
|
||||
layer.mlp.op_experts,
|
||||
layer.mlp.op_combine_a,
|
||||
operations.YieldOperation(),
|
||||
layer.mlp.op_combine_b,
|
||||
layer.mlp.op_output,
|
||||
layer.op_comm_postprocess_layer,
|
||||
operations.YieldOperation(),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -356,14 +356,14 @@ def model_forward_maybe_tbo(
|
||||
forward_batch: ForwardBatch,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
zero_allocator: Optional[BumpAllocator] = None,
|
||||
):
|
||||
inputs = dict(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
residual=residual,
|
||||
zero_allocator=zero_allocator,
|
||||
**(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
|
||||
)
|
||||
operations_strategy = OperationsStrategy.init_new_tbo(
|
||||
layers, forward_batch.global_forward_mode
|
||||
@@ -401,7 +401,7 @@ def _model_forward_tbo_split_inputs(
|
||||
residual: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
zero_allocator: BumpAllocator,
|
||||
zero_allocator: Optional[BumpAllocator] = None,
|
||||
) -> List[Dict]:
|
||||
return [
|
||||
dict(
|
||||
@@ -412,7 +412,11 @@ def _model_forward_tbo_split_inputs(
|
||||
output_forward_batch=output_forward_batch,
|
||||
tbo_subbatch_index=tbo_subbatch_index,
|
||||
),
|
||||
zero_allocator=zero_allocator,
|
||||
**(
|
||||
dict(zero_allocator=zero_allocator)
|
||||
if zero_allocator is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
for tbo_subbatch_index, output_forward_batch in enumerate(
|
||||
forward_batch.tbo_children
|
||||
|
||||
Reference in New Issue
Block a user