diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index e1a69ae15..a1b2d4723 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return input_ids, reqs
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
 
 
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +254,7 @@ def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +262,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 533db7d87..8b0573e50 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1466,14 +1466,36 @@ class Scheduler(
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ class Scheduler(
         else:
             can_cuda_graph = 0
 
-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,13 +1532,13 @@ class Scheduler(
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -1524,14 +1546,14 @@ class Scheduler(
         is_extend_in_batch = global_info[:, 0, 3].tolist()
 
         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()
 
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
 
             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
         return local_batch, any(is_extend_in_batch)