Support overlapping two batches (#4068)
This commit is contained in:
@@ -78,6 +78,7 @@ global_server_args_dict = {
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
|
||||
"enable_dp_attention": ServerArgs.enable_dp_attention,
|
||||
"enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
|
||||
"enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
|
||||
"enable_ep_moe": ServerArgs.enable_ep_moe,
|
||||
"deepep_config": ServerArgs.deepep_config,
|
||||
@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
@@ -1624,6 +1627,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
or global_server_args_dict["attention_backend"] == "flashmla"
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
or global_server_args_dict["attention_backend"] == "cutlass_mla"
|
||||
or global_server_args_dict["enable_two_batch_overlap"]
|
||||
):
|
||||
seq_lens_cpu = self.seq_lens.cpu()
|
||||
else:
|
||||
@@ -1651,6 +1655,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
tbo_split_seq_index=self.tbo_split_seq_index,
|
||||
global_forward_mode=self.global_forward_mode,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
extend_num_tokens=self.extend_num_tokens,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
can_run_dp_cuda_graph: bool
|
||||
tbo_split_seq_index: Optional[int]
|
||||
global_forward_mode: Optional[ForwardMode]
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int]
|
||||
|
||||
@@ -34,6 +34,7 @@ import zmq
|
||||
from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
|
||||
from sglang.srt.disaggregation.decode import (
|
||||
@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
|
||||
from sglang.srt.utils import (
|
||||
DeepEPMode,
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
configure_logger,
|
||||
@@ -1648,6 +1651,9 @@ class Scheduler(
|
||||
disable_cuda_graph=self.server_args.disable_cuda_graph,
|
||||
spec_algorithm=self.spec_algorithm,
|
||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
|
||||
enable_deepep_moe=self.server_args.enable_deepep_moe,
|
||||
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1661,6 +1667,9 @@ class Scheduler(
|
||||
disable_cuda_graph: bool,
|
||||
spec_algorithm,
|
||||
speculative_num_draft_tokens,
|
||||
enable_two_batch_overlap: bool,
|
||||
enable_deepep_moe: bool,
|
||||
deepep_mode: DeepEPMode,
|
||||
):
|
||||
# Check if other DP workers have running batches
|
||||
if local_batch is None:
|
||||
@@ -1696,17 +1705,26 @@ class Scheduler(
|
||||
is_extend_in_batch = (
|
||||
local_batch.forward_mode.is_extend() if local_batch else False
|
||||
)
|
||||
|
||||
tbo_preparer = TboDPAttentionPreparer()
|
||||
|
||||
local_info = torch.tensor(
|
||||
[
|
||||
num_tokens,
|
||||
can_cuda_graph,
|
||||
num_tokens_for_logprob,
|
||||
is_extend_in_batch,
|
||||
*tbo_preparer.prepare_all_gather(
|
||||
local_batch,
|
||||
deepep_mode,
|
||||
enable_deepep_moe,
|
||||
enable_two_batch_overlap,
|
||||
),
|
||||
],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
global_info = torch.empty(
|
||||
-(dp_size, attn_tp_size, 4),
|
||||
+(dp_size, attn_tp_size, 6),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
@@ -1719,6 +1737,10 @@ class Scheduler(
|
||||
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
|
||||
is_extend_in_batch = global_info[:, 0, 3].tolist()
|
||||
|
||||
tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
|
||||
global_info[:, :, 4:6]
|
||||
)
|
||||
|
||||
if local_batch is None and max(global_num_tokens) > 0:
|
||||
local_batch = get_idle_batch()
|
||||
|
||||
@@ -1732,6 +1754,8 @@ class Scheduler(
|
||||
local_batch.global_num_tokens_for_logprob = (
|
||||
global_num_tokens_for_logprob
|
||||
)
|
||||
local_batch.tbo_split_seq_index = tbo_split_seq_index
|
||||
local_batch.global_forward_mode = global_forward_mode
|
||||
|
||||
# Check forward mode for cuda graph
|
||||
if not disable_cuda_graph:
|
||||
|
||||
Reference in New Issue
Block a user