Support overlapping two batches (#4068)

fzyzcjy
2025-05-25 08:39:07 +08:00
committed by GitHub
parent f456037396
commit 0d47788025
13 changed files with 1145 additions and 129 deletions
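
Editor's note: "two batch overlap" (TBO) splits one forward batch into two micro-batches and interleaves their execution, so that while one micro-batch waits on communication (the DeepEP dispatch/combine steps in the operation lists below) the other runs compute; the YieldOperation() entries appear to mark the switch points. Below is a minimal sketch of that scheduling idea, using plain Python generators and hypothetical stage names rather than this commit's actual executor:

from typing import Callable, Iterator, List


def _micro_batch(stages: List[Callable[[int], None]], index: int) -> Iterator[None]:
    # One micro-batch: run each stage, yielding control between stages
    # (the role YieldOperation() plays in the real operation lists).
    for stage in stages:
        stage(index)
        yield


def run_two_batch_overlap(stages: List[Callable[[int], None]]) -> None:
    # Alternate between the two micro-batches at every yield point, so a
    # communication stage of one can overlap a compute stage of the other.
    pending = [_micro_batch(stages, 0), _micro_batch(stages, 1)]
    while pending:
        for gen in list(pending):
            try:
                next(gen)
            except StopIteration:
                pending.remove(gen)


# Hypothetical stages loosely mirroring the decode schedule in the diff:
run_two_batch_overlap(
    [
        lambda i: print(f"micro-batch {i}: attention"),
        lambda i: print(f"micro-batch {i}: dispatch (comm)"),
        lambda i: print(f"micro-batch {i}: experts (compute)"),
        lambda i: print(f"micro-batch {i}: combine (comm)"),
    ]
)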


@@ -1,33 +1,116 @@
from dataclasses import dataclass
from typing import List, Optional

import torch

from sglang.srt import operations
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.operations import Operation


@dataclass
class OperationsStrategy:
    operations: List[Operation]
    deep_gemm_num_sms: Optional[int] = None
    tbo_delta_stages: Optional[int] = None

    @classmethod
    def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
        return OperationsStrategy(
            operations=[x for item in items for x in item.operations],
            deep_gemm_num_sms=_assert_all_same(
                [item.deep_gemm_num_sms for item in items]
            ),
            tbo_delta_stages=_assert_all_same(
                [item.tbo_delta_stages for item in items]
            ),
        )

    @staticmethod
    def init_new_tbo(
        layers: torch.nn.ModuleList,
        forward_mode: ForwardMode,
    ) -> "OperationsStrategy":
        return OperationsStrategy.concat(
            [
                _compute_layer_operations_strategy_tbo(layer, forward_mode)
                for layer in layers
            ]
        )
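
# Editor's note (not part of this commit): illustrative usage, with a
# hypothetical `model`, is
#   strategy = OperationsStrategy.init_new_tbo(model.layers, ForwardMode.DECODE)
# `concat` flattens the per-layer operation lists in order and asserts that
# every layer agrees on deep_gemm_num_sms and tbo_delta_stages.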


def _assert_all_same(items: List):
    assert all(item == items[0] for item in items)
    return items[0]
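
# Editor's note: _assert_all_same([2, 2, 2]) returns 2, while a mixed list such
# as [2, 3] trips the assertion, so layers with conflicting settings cannot be
# merged into one strategy.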


# TODO can refactor to make it more fancy if we have more complex strategies
def _compute_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_deepseek_blog_prefill(layer)
    elif forward_mode == ForwardMode.DECODE:
        return _compute_moe_deepseek_blog_decode(layer)
    else:
        raise NotImplementedError(f"Unsupported {forward_mode=}")
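
# Editor's note: the "_moe_deepseek_blog_*" names suggest these orderings follow
# the overlap schedule DeepSeek described publicly; EXTEND (prefill) and DECODE
# get different yield placements, and any other ForwardMode is rejected above.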


def _compute_moe_deepseek_blog_prefill(layer):
    device_properties = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = device_properties.multi_processor_count
    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        tbo_delta_stages=0,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            layer.mlp.op_dispatch_a,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_shared_experts,
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
        ],
    )
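
# Editor's note: prefill overlaps compute and communication by statically
# splitting SMs: DeepEP keeps its configured share for dispatch/combine kernels
# and DeepGEMM gets the rest. With hypothetical numbers, a 132-SM H100 and a
# 24-SM DeepEP config would leave 132 - 24 = 108 SMs for DeepGEMM.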


def _compute_moe_deepseek_blog_decode(layer):
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            operations.YieldOperation(),
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_a,
            layer.mlp.op_shared_experts,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
            operations.YieldOperation(),
        ],
    )
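
# Editor's note: decode instead sets tbo_delta_stages=2, which reads as letting
# one micro-batch run two yield stages ahead of the other so that dispatch and
# combine phases stay staggered against attention and expert compute; this
# interpretation is inferred from the code, not stated in the commit.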