From 107710268ad880f9328b27eb3468adf5940c6bb8 Mon Sep 17 00:00:00 2001
From: IAN <50618241+hcyz33@users.noreply.github.com>
Date: Wed, 26 Feb 2025 01:32:05 +0800
Subject: [PATCH] [BugFix] Fix crash when receiving a req with structured
 output in DP attention mode. (#3841)

---
 benchmark/json_decode_regex/bench_sglang.py |  1 +
 python/sglang/srt/managers/scheduler.py     | 42 ++++++++++++++++-----
 2 files changed, 33 insertions(+), 10 deletions(-)

diff --git a/benchmark/json_decode_regex/bench_sglang.py b/benchmark/json_decode_regex/bench_sglang.py
index 462c77750..4139ebf8a 100644
--- a/benchmark/json_decode_regex/bench_sglang.py
+++ b/benchmark/json_decode_regex/bench_sglang.py
@@ -46,6 +46,7 @@ def json_decode(s, document):
 
 def main(args):
     lines = read_jsonl(args.data_path)
+    lines = list(lines)
     arguments = []
     for i in range(len(lines[: args.num_questions])):
         arguments.append(
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 4cd69d0ed..e4a141a9c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1154,6 +1154,10 @@ class Scheduler:
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_batch_result(result.bid)
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
@@ -1630,16 +1634,34 @@ class Scheduler:
             except futures._base.TimeoutError:
                 break
 
-        if self.tp_size > 1:
-            # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
-            torch.distributed.all_reduce(
-                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
-            )
-            num_ready_reqs_max = tensor.item()
-            for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_size > 1:
+                # Sync across attn TP ranks to make sure they have the same number of ready requests
+                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+                torch.distributed.all_reduce(
+                    tensor,
+                    op=torch.distributed.ReduceOp.MAX,
+                    group=self.attn_tp_cpu_group,
+                )
+                num_ready_reqs_max = tensor.item()
+                for i in range(num_ready_reqs, num_ready_reqs_max):
+                    self.grammar_queue[i].grammar = self.grammar_queue[
+                        i
+                    ].grammar.result()
+                num_ready_reqs = num_ready_reqs_max
+        else:
+            if self.tp_size > 1:
+                # Sync across TP ranks to make sure they have the same number of ready requests
+                tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+                torch.distributed.all_reduce(
+                    tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+                )
+                num_ready_reqs_max = tensor.item()
+                for i in range(num_ready_reqs, num_ready_reqs_max):
+                    self.grammar_queue[i].grammar = self.grammar_queue[
+                        i
+                    ].grammar.result()
+                num_ready_reqs = num_ready_reqs_max
 
         self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
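
The bench_sglang.py change guards against read_jsonl returning a lazy iterator:
a generator cannot be sliced or passed to len(), so lines[: args.num_questions]
would raise a TypeError before this patch. A minimal sketch of the failure and
the fix, assuming read_jsonl yields one parsed JSON object per line (the data
file name is illustrative):

    import json

    def read_jsonl(path):
        # One-shot iterator: yields each line as a parsed object.
        with open(path) as f:
            for line in f:
                yield json.loads(line)

    lines = read_jsonl("questions.jsonl")  # hypothetical data file
    # lines[:10]  # TypeError: 'generator' object is not subscriptable
    lines = list(lines)  # materialize once so slicing and len() work
    subset = lines[:10]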
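
The first scheduler.py hunk addresses the crash itself. With DP attention
enabled, a data-parallel rank that has no requests still runs an idle forward
pass; before the patch, the is_idle() branch skipped the sampling-info
handshake that the is_dummy_first() branch already performs. If the next batch
carries a structured-output (grammar-constrained) request, the overlap sampler
then waits on sampling_info_done for a vocab mask that is never published. A
minimal sketch of that handshake, using simplified stand-ins for sglang's
internal classes:

    import threading

    class SamplingInfo:
        # Simplified stand-in for the batch's sampling metadata.
        def __init__(self):
            self.vocab_mask = None
            self.sampling_info_done = threading.Event()

        def update_regex_vocab_mask(self):
            # In sglang this fills a token mask from the compiled grammar;
            # a placeholder is enough to show the handshake.
            self.vocab_mask = object()

    def scheduler_step(next_batch_sampling_info):
        # The fix: even on an idle forward pass, publish the mask and set
        # the event, or sampler_worker below blocks forever.
        if next_batch_sampling_info:
            next_batch_sampling_info.update_regex_vocab_mask()
            next_batch_sampling_info.sampling_info_done.set()

    def sampler_worker(info):
        info.sampling_info_done.wait()  # overlap sampler waits for the mask
        assert info.vocab_mask is not None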
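
The second scheduler.py hunk fixes the grammar-queue synchronization. Under DP
attention, the ready-request count should be agreed on only within the
attention tensor-parallel group (attn_tp_cpu_group); reducing over the full
tp_cpu_group would mix counts from DP ranks that hold entirely different
queues. The pattern is a MAX all-reduce, after which lagging ranks block on
the futures of their not-yet-ready grammars so every rank dequeues the same
prefix. A runnable single-process sketch of that reduction on the gloo CPU
backend (the demo setup is illustrative, not sglang's initialization):

    import os
    import torch
    import torch.distributed as dist

    def sync_ready_count(num_ready_reqs: int, group=None) -> int:
        # MAX-reduce the local ready count so all ranks in `group` agree
        # on how many grammar requests to move to the waiting queue.
        t = torch.tensor(num_ready_reqs, dtype=torch.int32)
        dist.all_reduce(t, op=dist.ReduceOp.MAX, group=group)
        return int(t.item())

    if __name__ == "__main__":
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        print(sync_ready_count(3))  # single rank: max is its own count, 3

Ranks whose local count is below the max then call .result() on the extra
grammar futures, blocking until those grammars finish compiling, which keeps
the subsequent batch composition identical across the ranks in the group.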