Support single batch overlap (#10422)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.offloader import get_offloader
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
from sglang.srt.utils import (
|
||||
ceil_div,
|
||||
dispose_tensor,
|
||||
get_bool_env_var,
|
||||
get_int_env_var,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
|
||||
),
|
||||
)
|
||||
|
||||
def moe_impl(self, dispatch_output: DispatchOutput):
|
||||
def moe_impl(
|
||||
self,
|
||||
dispatch_output: DispatchOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
|
||||
):
|
||||
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
|
||||
|
||||
if _use_aiter:
|
||||
@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
|
||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
||||
return self.forward_flashinfer_cutedsl(
|
||||
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||
)
|
||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||
return self.forward_deepgemm_masked(dispatch_output)
|
||||
else:
|
||||
@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
return self.deepep_dispatcher.combine(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
|
||||
def forward_aiter(
|
||||
@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
|
||||
def forward_flashinfer_cutedsl(
|
||||
self,
|
||||
dispatch_output: DeepEPLLOutput,
|
||||
down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
|
||||
):
|
||||
hidden_states, _, _, masked_m, _ = dispatch_output
|
||||
assert self.quant_method is not None
|
||||
@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
|
||||
x=hidden_states,
|
||||
masked_m=masked_m,
|
||||
moe_runner_config=self.moe_runner_config,
|
||||
down_gemm_overlap_args=down_gemm_overlap_args,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
|
||||
w2_blockscale: torch.Tensor,
|
||||
w2_alpha,
|
||||
masked_m: torch.Tensor,
|
||||
down_sm_count: Optional[int] = None,
|
||||
down_signals: Optional[torch.Tensor] = None,
|
||||
down_start_event: Optional[torch.cuda.Event] = None,
|
||||
):
|
||||
"""
|
||||
Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
|
||||
@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
|
||||
masked_m,
|
||||
)
|
||||
|
||||
if down_start_event is not None:
|
||||
down_start_event.record()
|
||||
|
||||
# Gemm2
|
||||
out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
|
||||
out = out.permute(1, 2, 0) # requirement of kernel
|
||||
@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
|
||||
sf_vec_size=sf_vec_size,
|
||||
alpha=w2_alpha.view(1, 1, num_experts),
|
||||
alpha_dtype=get_cute_dtype(w2_alpha),
|
||||
**(
|
||||
dict(
|
||||
sm_count=down_sm_count,
|
||||
dst_signals=down_signals,
|
||||
)
|
||||
if down_sm_count is not None or down_signals is not None
|
||||
else {}
|
||||
),
|
||||
) # in logical [m, k, l]
|
||||
return out.permute(2, 0, 1)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.moe.token_dispatcher.base import (
|
||||
@@ -25,6 +26,9 @@ from sglang.srt.utils import (
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.single_batch_overlap import CombineOverlapArgs
|
||||
|
||||
try:
|
||||
from deep_ep import Buffer, Config
|
||||
|
||||
@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
deepep_post_reorder_triton_kernel,
|
||||
@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
|
||||
"""
|
||||
self.return_recv_hook = return_recv_hook
|
||||
self.device_module = torch.get_device_module()
|
||||
|
||||
def dispatch_a(
|
||||
self,
|
||||
@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
use_fp8 = True
|
||||
|
||||
buffer = self._get_buffer()
|
||||
packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
|
||||
packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
|
||||
buffer.low_latency_dispatch(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
|
||||
)
|
||||
)
|
||||
return packed_recv_hidden, packed_recv_count, event, hook
|
||||
return packed_recv_hidden, self.packed_recv_count, event, hook
|
||||
|
||||
def combine_a(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
hidden_states, event, hook = self._combine_core(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
return hidden_states, event, hook
|
||||
return hidden_states, event, hook, overlap_args
|
||||
|
||||
def combine_b(self, hidden_states, event, hook):
|
||||
def combine_b(self, hidden_states, event, hook, overlap_args):
|
||||
hook() if self.return_recv_hook else event.current_stream_wait()
|
||||
|
||||
if overlap_args is not None:
|
||||
self.device_module.current_stream().wait_stream(overlap_args.stream)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _combine_core(
|
||||
@@ -615,17 +628,35 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
overlap_args: Optional["CombineOverlapArgs"],
|
||||
):
|
||||
buffer = self._get_buffer()
|
||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||
hidden_states,
|
||||
topk_idx,
|
||||
topk_weights,
|
||||
self.handle,
|
||||
async_finish=not self.return_recv_hook,
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
)
|
||||
self.handle = None
|
||||
|
||||
ctx = nullcontext()
|
||||
if overlap_args is not None:
|
||||
overlap_args.stream.wait_event(overlap_args.wait_event)
|
||||
ctx = torch.cuda.stream(overlap_args.stream)
|
||||
|
||||
with ctx:
|
||||
combined_hidden_states, event, hook = buffer.low_latency_combine(
|
||||
x=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
handle=self.handle,
|
||||
async_finish=not self.return_recv_hook,
|
||||
return_recv_hook=self.return_recv_hook,
|
||||
**(
|
||||
dict(
|
||||
overlap=overlap_args.overlap,
|
||||
src_signals=overlap_args.signal,
|
||||
src_signal_expect_value=overlap_args.threshold,
|
||||
)
|
||||
if overlap_args is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
self.packed_recv_count = self.handle = None
|
||||
return combined_hidden_states, event, hook
|
||||
|
||||
def _get_buffer(self):
|
||||
@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
overlap_args: Optional["CombineOverlapArgs"] = None,
|
||||
):
|
||||
self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
|
||||
inner_state = self._get_impl(forward_batch).combine_a(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
overlap_args=overlap_args,
|
||||
)
|
||||
self._combine_intermediate_state = forward_batch, inner_state
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
|
||||
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
|
||||
DEEPEP_MODE: Optional[DeepEPMode] = None
|
||||
IS_TBO_ENABLED: Optional[bool] = None
|
||||
IS_SBO_ENABLED: Optional[bool] = None
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
|
||||
DEEPEP_CONFIG: Optional[str] = None
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
|
||||
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
global DEEPEP_MODE
|
||||
global DEEPEP_CONFIG
|
||||
global IS_TBO_ENABLED
|
||||
global IS_SBO_ENABLED
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
|
||||
|
||||
@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
|
||||
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
|
||||
DEEPEP_CONFIG = server_args.deepep_config or ""
|
||||
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
|
||||
IS_SBO_ENABLED = server_args.enable_single_batch_overlap
|
||||
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
|
||||
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
|
||||
server_args.disable_flashinfer_cutlass_moe_fp4_allgather
|
||||
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
|
||||
return IS_TBO_ENABLED
|
||||
|
||||
|
||||
def is_sbo_enabled() -> bool:
|
||||
global IS_SBO_ENABLED
|
||||
if IS_SBO_ENABLED is None:
|
||||
IS_SBO_ENABLED = False
|
||||
return IS_SBO_ENABLED
|
||||
|
||||
|
||||
def get_tbo_token_distribution_threshold() -> float:
|
||||
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
|
||||
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
|
||||
|
||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
||||
CombineInput,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import scaled_fp4_quant
|
||||
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
x: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
moe_runner_config.activation == "silu"
|
||||
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alpha=layer.g2_alphas,
|
||||
masked_m=masked_m,
|
||||
**(
|
||||
dict(
|
||||
down_sm_count=down_gemm_overlap_args.num_sms,
|
||||
down_signals=down_gemm_overlap_args.signal,
|
||||
down_start_event=down_gemm_overlap_args.start_event,
|
||||
)
|
||||
if down_gemm_overlap_args is not None
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return out
|
||||
|
||||
@@ -28,6 +28,7 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt import single_batch_overlap
|
||||
from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_pp_group,
|
||||
@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.single_batch_overlap import SboFlags
|
||||
from sglang.srt.two_batch_overlap import (
|
||||
MaybeTboDeepEPDispatcher,
|
||||
model_forward_maybe_tbo,
|
||||
@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
if hidden_states.shape[0] > 0:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
if not SboFlags.fuse_shared_experts_inside_sbo():
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx, _ = self.topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states.device
|
||||
)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
|
||||
hidden_states=hidden_states,
|
||||
topk_idx=topk_idx,
|
||||
topk_weights=topk_weights,
|
||||
forward_batch=forward_batch,
|
||||
# SBO args
|
||||
forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
|
||||
experts=self.experts,
|
||||
alt_stream=self.alt_stream,
|
||||
)
|
||||
if sbo_shared_output is not None:
|
||||
shared_output = sbo_shared_output
|
||||
|
||||
if shared_output is not None:
|
||||
x = shared_output
|
||||
@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
def _forward_shared_experts(
|
||||
self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
|
||||
):
|
||||
if self.num_fused_shared_experts == 0:
|
||||
if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
|
||||
return self.shared_experts(
|
||||
hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
|
||||
)
|
||||
|
||||
@@ -377,6 +377,7 @@ class ServerArgs:
|
||||
enable_dp_attention: bool = False
|
||||
enable_dp_lm_head: bool = False
|
||||
enable_two_batch_overlap: bool = False
|
||||
enable_single_batch_overlap: bool = False
|
||||
tbo_token_distribution_threshold: float = 0.48
|
||||
enable_torch_compile: bool = False
|
||||
torch_compile_max_bs: int = 32
|
||||
@@ -2457,6 +2458,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enabling two micro batches to overlap.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-single-batch-overlap",
|
||||
action="store_true",
|
||||
help="Let computation and communication overlap within one micro batch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tbo-token-distribution-threshold",
|
||||
type=float,
|
||||
|
||||
151
python/sglang/srt/single_batch_overlap.py
Normal file
151
python/sglang/srt/single_batch_overlap.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.moe import get_moe_runner_backend
|
||||
from sglang.srt.layers.moe.utils import is_sbo_enabled
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import get_int_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
|
||||
|
||||
|
||||
class SboFlags:
|
||||
# TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...
|
||||
|
||||
@classmethod
|
||||
def enable_combine_down_gemm_two_stream_overlap(cls):
|
||||
return (
|
||||
is_sbo_enabled()
|
||||
# currently only cutedsl backend supports it
|
||||
and get_moe_runner_backend().is_flashinfer_cutedsl()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def enable_combine_shared_two_stream_overlap(cls):
|
||||
return is_sbo_enabled()
|
||||
|
||||
@classmethod
|
||||
def fuse_shared_experts_inside_sbo(cls):
|
||||
# TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
|
||||
return cls.enable_combine_shared_two_stream_overlap()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombineOverlapArgs:
|
||||
# this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
|
||||
overlap: bool
|
||||
stream: torch.cuda.Stream
|
||||
wait_event: torch.cuda.Event
|
||||
num_sms: int
|
||||
signal: Optional[torch.Tensor] = None
|
||||
threshold: int = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownGemmOverlapArgs:
|
||||
num_sms: int
|
||||
signal: torch.Tensor
|
||||
start_event: torch.cuda.Event
|
||||
|
||||
|
||||
def execute_sbo(
|
||||
forward_shared_experts: Callable[[], Any],
|
||||
experts: "DeepEPMoE",
|
||||
hidden_states: torch.Tensor,
|
||||
topk_idx: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
alt_stream: Optional = None,
|
||||
):
|
||||
shared_output = None
|
||||
|
||||
dispatch_output = experts.dispatch(
|
||||
hidden_states, topk_idx, topk_weights, forward_batch
|
||||
)
|
||||
|
||||
combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
|
||||
_compute_overlap_args(dispatch_output, alt_stream)
|
||||
)
|
||||
|
||||
hidden_states = experts.moe_impl(
|
||||
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
||||
)
|
||||
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
||||
e.record()
|
||||
|
||||
if SboFlags.enable_combine_shared_two_stream_overlap():
|
||||
# TODO reduce sm for non-deepgemm
|
||||
with deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
||||
meta_overlap_args["compute_num_sms"]
|
||||
):
|
||||
shared_output = forward_shared_experts()
|
||||
|
||||
hidden_states = experts.combine(
|
||||
hidden_states,
|
||||
dispatch_output.topk_idx,
|
||||
dispatch_output.topk_weights,
|
||||
forward_batch,
|
||||
overlap_args=combine_overlap_args,
|
||||
)
|
||||
|
||||
return hidden_states, shared_output
|
||||
|
||||
|
||||
def _compute_overlap_args(dispatch_output, alt_stream):
|
||||
if not (
|
||||
SboFlags.enable_combine_down_gemm_two_stream_overlap()
|
||||
or SboFlags.enable_combine_shared_two_stream_overlap()
|
||||
):
|
||||
return None, None, {}
|
||||
|
||||
hidden_states = dispatch_output.hidden_states_fp8
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
|
||||
|
||||
total_num_sms = torch.cuda.get_device_properties(
|
||||
device="cuda"
|
||||
).multi_processor_count
|
||||
communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
|
||||
compute_num_sms = total_num_sms - communicate_num_sms
|
||||
|
||||
assert alt_stream is not None
|
||||
combine_wait_event = torch.cuda.Event()
|
||||
combine_overlap_args = CombineOverlapArgs(
|
||||
overlap=False,
|
||||
num_sms=communicate_num_sms,
|
||||
stream=alt_stream,
|
||||
wait_event=combine_wait_event,
|
||||
)
|
||||
meta_overlap_args = dict(
|
||||
compute_num_sms=compute_num_sms,
|
||||
)
|
||||
down_gemm_overlap_args = None
|
||||
|
||||
if SboFlags.enable_combine_down_gemm_two_stream_overlap():
|
||||
# TODO use zero_allocator to remove this `torch.zeros` call
|
||||
# NOTE ours v2 use uint32 not int32 currently
|
||||
combine_signal = torch.zeros(
|
||||
num_local_experts, dtype=torch.uint32, device=hidden_states.device
|
||||
)
|
||||
|
||||
down_gemm_overlap_args = DownGemmOverlapArgs(
|
||||
signal=combine_signal,
|
||||
start_event=combine_wait_event,
|
||||
num_sms=compute_num_sms,
|
||||
)
|
||||
combine_overlap_args.overlap = True
|
||||
combine_overlap_args.signal = combine_signal
|
||||
combine_overlap_args.threshold = compute_num_sms
|
||||
else:
|
||||
meta_overlap_args |= dict(
|
||||
record_event_after_down=combine_wait_event,
|
||||
)
|
||||
|
||||
return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
|
||||
Reference in New Issue
Block a user