qwen3moe support two batch overlap (#6598)

This commit is contained in:
Yi Zhang
2025-05-26 14:08:16 +08:00
committed by GitHub
parent 16f69b1f65
commit f9bab3d591
5 changed files with 355 additions and 32 deletions

View File

@@ -32,12 +32,27 @@ class OperationsStrategy:
layers: torch.nn.ModuleList,
forward_mode: ForwardMode,
) -> "OperationsStrategy":
return OperationsStrategy.concat(
[
_compute_layer_operations_strategy_tbo(layer, forward_mode)
for layer in layers
]
)
layer_name = layers[0].__class__.__name__
if layer_name == "DeepseekV2DecoderLayer":
return OperationsStrategy.concat(
[
_compute_moe_deepseek_layer_operations_strategy_tbo(
layer, forward_mode
)
for layer in layers
]
)
elif layer_name == "Qwen3MoeDecoderLayer":
return OperationsStrategy.concat(
[
_compute_moe_qwen3_layer_operations_strategy_tbo(
layer, forward_mode
)
for layer in layers
]
)
else:
raise NotImplementedError
def _assert_all_same(items: List):
@@ -45,8 +60,11 @@ def _assert_all_same(items: List):
return items[0]
# -------------------------------- Strategy for DeepSeek ---------------------------------------
# TODO can refactor to make it more fancy if we have more complex strategies
def _compute_layer_operations_strategy_tbo(
def _compute_moe_deepseek_layer_operations_strategy_tbo(
layer: torch.nn.Module,
forward_mode: ForwardMode,
) -> OperationsStrategy:
@@ -114,3 +132,76 @@ def _compute_moe_deepseek_blog_decode(layer):
operations.YieldOperation(),
],
)
# -------------------------------- Strategy for Qwen3 ---------------------------------------
# TODO: unstable, current strategy is almost the same as DeepSeek, keep redundant code here for
# convenience to adjust strategy
def _compute_moe_qwen3_layer_operations_strategy_tbo(
layer: torch.nn.Module,
forward_mode: ForwardMode,
) -> OperationsStrategy:
assert layer.is_layer_sparse, "qwen3 moe only support sparse layers"
if forward_mode == ForwardMode.EXTEND:
return _compute_moe_qwen3_prefill(layer)
elif forward_mode == ForwardMode.DECODE:
return _compute_moe_qwen3_decode(layer)
else:
raise NotImplementedError(f"Unsupported {forward_mode=}")
def _compute_moe_qwen3_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
tbo_delta_stages=0,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
layer.mlp.op_dispatch_a,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
],
)
def _compute_moe_qwen3_decode(layer):
return OperationsStrategy(
deep_gemm_num_sms=None,
tbo_delta_stages=2,
operations=[
layer.op_comm_prepare_attn,
layer.self_attn.op_prepare,
operations.YieldOperation(),
layer.self_attn.op_core,
layer.op_comm_prepare_mlp,
layer.mlp.op_gate,
layer.mlp.op_select_experts,
operations.YieldOperation(),
layer.mlp.op_dispatch_a,
operations.YieldOperation(),
layer.mlp.op_dispatch_b,
layer.mlp.op_experts,
layer.mlp.op_combine_a,
operations.YieldOperation(),
layer.mlp.op_combine_b,
layer.mlp.op_output,
layer.op_comm_postprocess_layer,
operations.YieldOperation(),
],
)