Support overlapping two batches (#4068)

This commit is contained in:
fzyzcjy
2025-05-25 08:39:07 +08:00
committed by GitHub
parent f456037396
commit 0d47788025
13 changed files with 1145 additions and 129 deletions

View File

@@ -34,6 +34,7 @@ import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_logger,
@@ -1648,6 +1651,9 @@ class Scheduler(
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
)
@staticmethod
@@ -1661,6 +1667,9 @@ class Scheduler(
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
):
# Check if other DP workers have running batches
if local_batch is None:
@@ -1696,17 +1705,26 @@ class Scheduler(
is_extend_in_batch = (
local_batch.forward_mode.is_extend() if local_batch else False
)
tbo_preparer = TboDPAttentionPreparer()
local_info = torch.tensor(
[
num_tokens,
can_cuda_graph,
num_tokens_for_logprob,
is_extend_in_batch,
*tbo_preparer.prepare_all_gather(
local_batch,
deepep_mode,
enable_deepep_moe,
enable_two_batch_overlap,
),
],
dtype=torch.int64,
)
global_info = torch.empty(
(dp_size, attn_tp_size, 4),
(dp_size, attn_tp_size, 6),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
@@ -1719,6 +1737,10 @@ class Scheduler(
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()
tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
global_info[:, :, 4:6]
)
if local_batch is None and max(global_num_tokens) > 0:
local_batch = get_idle_batch()
@@ -1732,6 +1754,8 @@ class Scheduler(
local_batch.global_num_tokens_for_logprob = (
global_num_tokens_for_logprob
)
local_batch.tbo_split_seq_index = tbo_split_seq_index
local_batch.global_forward_mode = global_forward_mode
# Check forward mode for cuda graph
if not disable_cuda_graph: