From 10d60cd41bb520d2cbd16f577f6d60f578e3ab4a Mon Sep 17 00:00:00 2001 From: u4lr451 Date: Tue, 17 Jun 2025 15:33:28 +0800 Subject: [PATCH] feat: mtp support dp-attention (#6081) Co-authored-by: austindeng Co-authored-by: tianqilin.99 Co-authored-by: Qiaolin Yu Co-authored-by: ch-wan --- .../srt/layers/attention/aiter_backend.py | 7 +- .../srt/layers/attention/base_attn_backend.py | 2 +- .../layers/attention/cutlass_mla_backend.py | 1 + .../attention/flashattention_backend.py | 6 +- .../layers/attention/flashinfer_backend.py | 13 +- .../attention/flashinfer_mla_backend.py | 9 +- .../srt/layers/attention/flashmla_backend.py | 7 +- .../srt/layers/attention/tbo_backend.py | 6 +- .../srt/layers/attention/triton_backend.py | 30 +-- python/sglang/srt/layers/dp_attention.py | 8 + python/sglang/srt/managers/schedule_batch.py | 8 +- python/sglang/srt/managers/scheduler.py | 36 +++- .../srt/model_executor/cuda_graph_runner.py | 36 ++-- .../srt/model_executor/forward_batch_info.py | 23 ++- .../sglang/srt/model_executor/model_runner.py | 2 + python/sglang/srt/models/deepseek_nextn.py | 44 ++--- python/sglang/srt/models/deepseek_v2.py | 7 + .../eagle_draft_cuda_graph_runner.py | 114 ++++++++++-- .../eagle_draft_extend_cuda_graph_runner.py | 107 +++++++++-- python/sglang/srt/speculative/eagle_utils.py | 78 +++++++- python/sglang/srt/speculative/eagle_worker.py | 176 +++++++++++++----- test/srt/test_dp_attention.py | 72 +++++++ 22 files changed, 641 insertions(+), 151 deletions(-) diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index ff7b9dc7e..7e6b9936e 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend): ) def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, ): self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int) if kv_indices_buf is None: @@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend): if not self.skip_prefill: self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), + (max_num_tokens * self.max_context_len), dtype=torch.uint8, device=self.device, ) diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index a38c319c7..bddd7891f 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -19,7 +19,7 @@ class AttentionBackend(ABC): """Init the metadata for a forward pass.""" raise NotImplementedError() - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): """Init the global shared states for cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index 65fff548e..fcfd648d0 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): def init_cuda_graph_state( self, max_bs: int, + max_num_tokens: int, block_kv_indices: Optional[torch.Tensor] = None, ): if block_kv_indices is None: diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py 
b/python/sglang/srt/layers/attention/flashattention_backend.py index e5210e88c..9871ca3a8 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend): return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): """Initialize CUDA graph state for the attention backend. Args: @@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend: for i in range(self.speculative_num_steps - 1): self.attn_backends[i].init_forward_metadata(forward_batch) - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): for i in range(self.speculative_num_steps): - self.attn_backends[i].init_cuda_graph_state(max_bs) + self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens) def init_forward_metadata_capture_cuda_graph( self, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 316ad18b0..7d62f7821 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend): ) def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, ): if kv_indices_buf is None: cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len,), + (max_num_tokens * self.max_context_len,), dtype=torch.int32, device="cuda", ) @@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend): if not self.skip_prefill: self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), + (max_num_tokens * self.max_context_len), dtype=torch.uint8, device="cuda", ) @@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend: self.common_template(forward_batch, kv_indices, call_fn) - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.cuda_graph_kv_indices = torch.zeros( (self.speculative_num_steps, max_bs * self.max_context_len), dtype=torch.int32, @@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend: for i in range(self.speculative_num_steps): self.attn_backends[i].init_cuda_graph_state( - max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i] ) def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index c6beb5820..1b8dc64e5 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend): ) def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, ): if kv_indices_buf is None: cuda_graph_kv_indices = torch.zeros( @@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend: self.common_template(forward_batch, kv_indices, call_fn) - def init_cuda_graph_state(self, max_bs: int): + def 
init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.cuda_graph_kv_indices = torch.zeros( (self.speculative_num_steps, max_bs * self.max_context_len), dtype=torch.int32, @@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend: for i in range(self.speculative_num_steps): self.attn_backends[i].init_cuda_graph_state( - max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i] ) def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index c688fd461..d1acb1a58 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): def init_cuda_graph_state( self, max_bs: int, + max_num_tokens: int, block_kv_indices: Optional[torch.Tensor] = None, ): if block_kv_indices is None: @@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend: self.common_template(forward_batch, call_fn) - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): for i in range(self.speculative_num_steps): - self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None) + self.attn_backends[i].init_cuda_graph_state( + max_bs, max_num_tokens, block_kv_indices=None + ) def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): def call_fn(i, forward_batch): diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py index afded3c33..4ad8c5b87 100644 --- a/python/sglang/srt/layers/attention/tbo_backend.py +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend): if forward_batch_child.batch_size > 0: child.init_forward_metadata(forward_batch=forward_batch_child) - def init_cuda_graph_state(self, max_bs: int): - self.primary.init_cuda_graph_state(max_bs=max_bs) + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens) for item in self.children: # TODO for children, maybe can provide *smaller* max_bs to optimize - item.init_cuda_graph_state(max_bs=max_bs) + item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens) def init_forward_metadata_capture_cuda_graph( self, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 970ceb999..c46c8cd4d 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend): num_kv_splits = None attn_logits = None attn_lse = None + elif forward_batch.forward_mode.is_draft_extend(): kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( @@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend): ) def init_cuda_graph_state( - self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, ): self.cuda_graph_attn_logits = torch.zeros( - (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim), + (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim), dtype=torch.float32, device=self.device, ) 
self.cuda_graph_attn_lse = torch.zeros( - (max_bs, self.num_head, self.max_kv_splits), + (max_num_tokens, self.num_head, self.max_kv_splits), dtype=torch.float32, device=self.device, ) self.cuda_graph_num_kv_splits = torch.full( - (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device + (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device ) if kv_indices_buf is None: self.cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len), + (max_num_tokens * self.max_context_len), dtype=torch.int32, device=self.device, ) @@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend): if not self.skip_prefill: self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), + (max_num_tokens * self.max_context_len), dtype=torch.uint8, device=self.device, ) @@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend): if self.sliding_window_size is not None and self.sliding_window_size > 0: if kv_indices_buf is None: self.cuda_graph_window_kv_indices = torch.zeros( - (max_bs * self.sliding_window_size), + (max_num_tokens * self.sliding_window_size), dtype=torch.int32, device=self.device, ) @@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf) self.cuda_graph_window_num_kv_splits = torch.full( - (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device + (max_num_tokens,), + self.max_kv_splits, + dtype=torch.int32, + device=self.device, ) def init_forward_metadata_capture_cuda_graph( @@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend): ) custom_mask = self.cuda_graph_custom_mask + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens) mask_indptr = self.mask_indptr[: bs + 1] mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) @@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend: self.common_template(forward_batch, kv_indices, call_fn) - def init_cuda_graph_state(self, max_bs: int): + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): self.cuda_graph_kv_indices = torch.zeros( - (self.speculative_num_steps, max_bs * self.max_context_len), + (self.speculative_num_steps, max_num_tokens * self.max_context_len), dtype=torch.int32, device=self.device, ) for i in range(self.speculative_num_steps): self.attn_backends[i].init_cuda_graph_state( - max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i] ) def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index b1862ff2c..6506aa10b 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -238,6 +238,10 @@ def _dp_gather( assert ( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between global_tokens and local_tokens not allowed" + if forward_batch.forward_mode.is_draft_extend(): + shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) + local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) + memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False ) @@ -288,6 +292,10 @@ def dp_scatter( assert ( local_tokens.untyped_storage() is not global_tokens.untyped_storage() ), "aliasing between local_tokens and global_tokens not allowed" + if 
forward_batch.forward_mode.is_draft_extend(): + shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) + local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) + memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 369340553..670293a5f 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_num_tokens: Optional[List[int]] = None global_num_tokens_for_logprob: Optional[List[int]] = None can_run_dp_cuda_graph: bool = False + is_extend_in_batch: bool = False tbo_split_seq_index: Optional[int] = None global_forward_mode: Optional[ForwardMode] = None @@ -1760,11 +1761,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): decoding_reqs=self.decoding_reqs, spec_algorithm=self.spec_algorithm, enable_custom_logit_processor=self.enable_custom_logit_processor, + global_num_tokens=self.global_num_tokens, + global_num_tokens_for_logprob=self.global_num_tokens_for_logprob, + can_run_dp_cuda_graph=self.can_run_dp_cuda_graph, + is_extend_in_batch=self.is_extend_in_batch, ) def __str__(self): return ( - f"ScheduleBatch(forward_mode={self.forward_mode.name}, " + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " f"#req={(len(self.reqs))})" ) @@ -1833,6 +1838,7 @@ class ModelWorkerBatch: spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None + spec_num_draft_tokens: Optional[int] = None # Overlap event launch_done: Optional[threading.Event] = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6b5a03b82..36e3bd3c2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1350,6 +1350,29 @@ class Scheduler( self.metrics_collector.log_stats(self.stats) self._publish_kv_events() + def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]): + """Coordinate the DP attention batch.""" + + local_info = torch.tensor( + [ + (new_batch is not None), + ], + dtype=torch.int64, + ) + global_info = torch.empty( + (self.server_args.dp_size, self.attn_tp_size, 1), + dtype=torch.int64, + ) + torch.distributed.all_gather_into_tensor( + global_info.flatten(), + local_info, + group=self.tp_cpu_group, + ) + any_new_batch = any( + global_info[:, 0, 0].tolist() + ) # Any DP worker has forward batch + return any_new_batch + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch chunked_req_to_exclude = set() @@ -1383,7 +1406,14 @@ class Scheduler( self.running_batch.merge_batch(self.last_batch) new_batch = self.get_new_batch_prefill() - if new_batch is not None: + + # TODO(ch-wan): minor refactor is needed here to improve readability + any_new_batch = ( + self.server_args.enable_dp_attention + and not self.spec_algorithm.is_none() + and self.coordinate_spec_dp_attn_batch(new_batch) + ) + if new_batch is not None or any_new_batch: # Run prefill first if possible ret = new_batch else: @@ -1732,8 +1762,6 @@ class Scheduler( num_tokens_for_logprob = 0 elif local_batch.forward_mode.is_decode(): num_tokens = local_batch.batch_size() - if not spec_algorithm.is_none() and 
spec_algorithm.is_eagle(): - num_tokens = num_tokens * speculative_num_draft_tokens num_tokens_for_logprob = num_tokens else: num_tokens = local_batch.extend_num_tokens @@ -1809,6 +1837,7 @@ class Scheduler( local_batch.global_num_tokens_for_logprob = ( global_num_tokens_for_logprob ) + local_batch.is_extend_in_batch = any(is_extend_in_batch) local_batch.tbo_split_seq_index = tbo_split_seq_index local_batch.global_forward_mode = global_forward_mode @@ -1816,6 +1845,7 @@ class Scheduler( if not disable_cuda_graph: local_batch.can_run_dp_cuda_graph = can_cuda_graph + # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here. return local_batch, any(is_extend_in_batch) def get_idle_batch(self): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1c534845d..cba7e869b 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -242,13 +242,13 @@ class CudaGraphRunner: # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs - if global_server_args_dict["attention_backend"] == "flashmla": - self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - else: - self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) + self.model_runner.attn_backend.init_cuda_graph_state( + self.max_bs, self.max_num_token + ) self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 self.seq_lens_cpu = torch.full( @@ -323,12 +323,15 @@ class CudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.enable_dp_attention or self.enable_sp_layernorm: - total_global_tokens = sum(forward_batch.global_num_tokens_cpu) - + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) + ) is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - total_global_tokens in self.graphs + total_batch_size in self.graphs if self.disable_padding - else total_global_tokens <= self.max_bs + else total_batch_size <= self.max_bs ) else: is_bs_supported = ( @@ -460,7 +463,7 @@ class CudaGraphRunner: self.global_num_tokens_gpu.copy_( torch.tensor( [ - num_tokens // self.dp_size + (i < bs % self.dp_size) + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) for i in range(self.dp_size) ], dtype=torch.int32, @@ -605,9 +608,12 @@ class CudaGraphRunner: # Pad if self.enable_dp_attention or self.enable_sp_layernorm: - index = bisect.bisect_left( - self.capture_bs, sum(forward_batch.global_num_tokens_cpu) + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) ) + index = bisect.bisect_left(self.capture_bs, total_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] @@ -650,13 +656,13 @@ class CudaGraphRunner: # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, - self.req_pool_indices, - self.seq_lens, + self.req_pool_indices[:bs], + self.seq_lens[:bs], forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value, - self.encoder_lens, + self.encoder_lens[:bs] 
if self.is_encoder_decoder else None, forward_batch.forward_mode, forward_batch.spec_info, - seq_lens_cpu=self.seq_lens_cpu, + seq_lens_cpu=self.seq_lens_cpu[:bs], ) # Store fields diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 97e48c10d..28d8c2bfa 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -320,17 +320,30 @@ class ForwardBatch: # For DP attention if batch.global_num_tokens is not None: - ret.global_num_tokens_cpu = batch.global_num_tokens + + spec_num_draft_tokens = ( + batch.spec_num_draft_tokens + if batch.spec_num_draft_tokens is not None + else 1 + ) + global_num_tokens = [ + x * spec_num_draft_tokens for x in batch.global_num_tokens + ] + global_num_tokens_for_logprob = [ + x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob + ] + + ret.global_num_tokens_cpu = global_num_tokens ret.global_num_tokens_gpu = torch.tensor( - batch.global_num_tokens, dtype=torch.int64 + global_num_tokens, dtype=torch.int64 ).to(device, non_blocking=True) - ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob + ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob ret.global_num_tokens_for_logprob_gpu = torch.tensor( - batch.global_num_tokens_for_logprob, dtype=torch.int64 + global_num_tokens_for_logprob, dtype=torch.int64 ).to(device, non_blocking=True) - sum_len = sum(batch.global_num_tokens) + sum_len = sum(global_num_tokens) ret.gathered_buffer = torch.zeros( (sum_len, model_runner.model_config.hidden_size), dtype=model_runner.dtype, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6ef23af24..ec546280b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -163,6 +163,7 @@ class ModelRunner: logger.addFilter(RankZeroFilter(tp_rank == 0)) self.tp_rank = tp_rank self.tp_size = tp_size + self.dp_size = server_args.dp_size self.pp_rank = pp_rank self.pp_size = pp_size self.dist_port = nccl_port @@ -196,6 +197,7 @@ class ModelRunner: | { # TODO it is indeed not a "server args" "use_mla_backend": self.use_mla_backend, + "speculative_algorithm": self.spec_algorithm, } ) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index ce31f355c..8d7bc7b18 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -22,7 +22,6 @@ from transformers import PretrainedConfig from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( @@ -77,6 +76,7 @@ class DeepseekModelNextN(nn.Module): forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, ) -> torch.Tensor: + zero_allocator = BumpAllocator( buffer_size=2, dtype=torch.float32, @@ -90,15 +90,16 @@ class DeepseekModelNextN(nn.Module): else: hidden_states = input_embeds - hidden_states = self.eh_proj( - torch.cat( - ( - self.enorm(hidden_states), - self.hnorm(forward_batch.spec_info.hidden_states), - ), - dim=-1, + if hidden_states.shape[0] > 0: + hidden_states = self.eh_proj( + torch.cat( + ( + 
self.enorm(hidden_states), + self.hnorm(forward_batch.spec_info.hidden_states), + ), + dim=-1, + ) ) - ) residual = None hidden_states, residual = self.decoder( @@ -127,23 +128,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self.model = DeepseekModelNextN( config, quant_config, prefix=add_prefix("model", prefix) ) - - if global_server_args_dict["enable_dp_attention"]: - self.lm_head = ReplicatedLinear( - config.hidden_size, - config.vocab_size, - bias=False, - prefix=add_prefix("model.shared_head.head", prefix), - ) - self.logits_processor = LogitsProcessor(config, skip_all_gather=True) - else: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("model.shared_head.head", prefix), - ) - self.logits_processor = LogitsProcessor(config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) @torch.no_grad() def forward( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 453a2f393..041ecc18c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module): rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.enable_dp_attention = global_server_args_dict["enable_dp_attention"] + self.speculative_algorithm = global_server_args_dict["speculative_algorithm"] self.layer_id = layer_id + self.is_nextn = is_nextn self.self_attn = DeepseekV2AttentionMLA( config=config, hidden_size=self.hidden_size, @@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states, residual, forward_batch ) + if self.enable_dp_attention and self.speculative_algorithm.is_eagle(): + # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks. + # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251). 
+ hidden_states = hidden_states.clone() + return hidden_states, residual def op_comm_prepare_attn( diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 5e70799e5..55ec71562 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner: self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm + self.dp_size = self.model_runner.dp_size self.tp_size = self.model_runner.tp_size self.topk = model_runner.server_args.speculative_eagle_topk self.speculative_num_steps = model_runner.server_args.speculative_num_steps @@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner: # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs - self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token) + self.model_runner.draft_attn_backend.init_cuda_graph_state( + self.max_bs, self.max_num_token + ) self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[ 0 ].get_cuda_graph_seq_len_fill_value() @@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner: self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) self.hidden_states = torch.zeros( - (self.max_num_token, self.model_runner.model_config.hidden_size), + (self.max_bs, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype, ) + if self.enable_dp_attention or self.enable_sp_layernorm: + # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + self.global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + # Capture try: with model_capture_mode(): @@ -92,11 +114,26 @@ class EAGLEDraftCudaGraphRunner: ) def can_run(self, forward_batch: ForwardBatch): - is_bs_supported = ( - forward_batch.batch_size in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) + if self.enable_dp_attention: + # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head + if not forward_batch.can_run_dp_cuda_graph: + return False + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) + ) + is_bs_supported = ( + total_batch_size in self.graphs + if self.disable_padding + else total_batch_size <= self.max_bs + ) + else: + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) return is_bs_supported def capture(self): @@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner: topk_index = self.topk_index[:num_seqs] hidden_states = self.hidden_states[:num_seqs] + if self.enable_dp_attention or 
self.enable_sp_layernorm: + self.global_num_tokens_gpu.copy_( + torch.tensor( + [ + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) + for i in range(self.dp_size) + ], + dtype=torch.int32, + device=self.input_ids.device, + ) + ) + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [ + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) + for i in range(self.dp_size) + ], + dtype=torch.int32, + device=self.input_ids.device, + ) + ) + global_num_tokens = self.global_num_tokens_gpu + gathered_buffer = self.gathered_buffer[:num_tokens] + global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu + else: + global_num_tokens = None + gathered_buffer = None + global_num_tokens_for_logprob = None + spec_info = EagleDraftInput( - topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + capture_hidden_mode=CaptureHiddenMode.LAST, ) # Forward batch @@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner: seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, + global_num_tokens_gpu=global_num_tokens, + gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, capture_hidden_mode=( spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL ), + global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob, ) # Attention backend @@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner: # Run and capture def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + # Backup two fields, which will be modified in-place in `draft_forward`. output_cache_loc_backup = forward_batch.out_cache_loc hidden_states_backup = forward_batch.spec_info.hidden_states @@ -184,7 +259,15 @@ class EAGLEDraftCudaGraphRunner: raw_num_token = raw_bs * self.num_tokens_per_bs # Pad - index = bisect.bisect_left(self.capture_bs, raw_bs) + if self.enable_dp_attention or self.enable_sp_layernorm: + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) + ) + index = bisect.bisect_left(self.capture_bs, total_batch_size) + else: + index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: self.seq_lens.fill_(self.seq_len_fill_value) @@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner: self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + if self.enable_dp_attention or self.enable_sp_layernorm: + self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) + self.global_num_tokens_for_logprob_gpu.copy_( + forward_batch.global_num_tokens_for_logprob_gpu + ) + forward_batch.gathered_buffer = self.gathered_buffer + # Attention backend if bs != raw_bs: forward_batch.batch_size = bs @@ -210,8 +300,10 @@ class EAGLEDraftCudaGraphRunner: forward_batch.req_pool_indices = self.req_pool_indices[:bs] forward_batch.positions = self.positions[:num_tokens] - if forward_batch.seq_lens_cpu is not None and bs != raw_bs: - self.seq_lens_cpu.fill_(self.seq_len_fill_value) + # Special handle for seq_len_cpu used when flashinfer mla is used + if forward_batch.seq_lens_cpu is not None: + if bs != raw_bs: + self.seq_lens_cpu.fill_(self.seq_len_fill_value) 
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs] diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 2d2fce197..4c7eb81f7 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner: self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm self.tp_size = self.model_runner.tp_size self.dp_size = model_runner.server_args.dp_size self.speculative_num_steps = model_runner.server_args.speculative_num_steps @@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner: self.max_num_token = self.max_bs * self.num_tokens_per_bs self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state( - self.max_num_token + self.max_bs, self.max_num_token ) self.seq_len_fill_value = ( self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value() @@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner: (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32 ) + if self.enable_dp_attention or self.enable_sp_layernorm: + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + self.global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + self.global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + # Capture try: with model_capture_mode(): @@ -100,15 +117,30 @@ class EAGLEDraftExtendCudaGraphRunner: ) def can_run(self, forward_batch: ForwardBatch): - batch_size = forward_batch.seq_lens.numel() + if self.enable_dp_attention or self.enable_sp_layernorm: + if not forward_batch.can_run_dp_cuda_graph: + return False + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) + ) + is_bs_supported = ( + total_batch_size in self.graphs + if self.disable_padding + else total_batch_size <= self.max_bs + ) + return is_bs_supported + else: + batch_size = forward_batch.seq_lens.numel() - is_bs_supported = ( - batch_size in self.graphs - if self.disable_padding - else batch_size <= self.max_bs - ) + is_bs_supported = ( + batch_size in self.graphs + if self.disable_padding + else batch_size <= self.max_bs + ) - return is_bs_supported + return is_bs_supported def capture(self): CudaGraphRunner.capture(self) @@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner: positions = self.positions[:num_tokens] hidden_states = self.hidden_states[:num_tokens] + if self.enable_dp_attention or self.enable_sp_layernorm: + self.global_num_tokens_gpu.copy_( + torch.tensor( + [ + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) + for i in range(self.dp_size) + ], + dtype=torch.int32, + device=self.input_ids.device, + ) + ) + self.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [ + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) + for i in range(self.dp_size) + ], + dtype=torch.int32, + device=self.input_ids.device, + ) + ) 
+ global_num_tokens = self.global_num_tokens_gpu + gathered_buffer = self.gathered_buffer[:num_tokens] + global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu + else: + global_num_tokens = None + gathered_buffer = None + global_num_tokens_for_logprob = None + spec_info = EagleDraftInput( hidden_states=hidden_states, accept_length=accept_length, @@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner: seq_lens_sum=seq_lens.sum().item(), return_logprob=False, positions=positions, + global_num_tokens_gpu=global_num_tokens, + global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob, + gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, capture_hidden_mode=CaptureHiddenMode.LAST, @@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner: # Run and capture def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + # Backup two fields, which will be modified in-place in `draft_forward`. output_cache_loc_backup = forward_batch.out_cache_loc hidden_states_backup = forward_batch.spec_info.hidden_states @@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner: # in the batch, which will not be counted as num_seqs raw_bs = forward_batch.batch_size num_tokens = forward_batch.input_ids.shape[0] + if self.enable_dp_attention or self.enable_sp_layernorm: + total_batch_size = ( + sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else sum(forward_batch.global_num_tokens_cpu) + ) + index = bisect.bisect_left(self.capture_bs, total_batch_size) + else: + index = bisect.bisect_left(self.capture_bs, raw_bs) - index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs * self.num_tokens_per_bs != num_tokens: self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() self.accept_length.fill_(1) + self.extend_seq_lens.fill_(1) # Common inputs self.input_ids[:num_tokens].copy_(forward_batch.input_ids) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens) + if forward_batch.extend_seq_lens is not None: + self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens) self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc) self.positions[:num_tokens].copy_(forward_batch.positions) self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states) - self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length) + if forward_batch.spec_info.accept_length is not None: + self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + if self.enable_dp_attention or self.enable_sp_layernorm: + self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) + self.global_num_tokens_for_logprob_gpu.copy_( + forward_batch.global_num_tokens_for_logprob_gpu + ) + forward_batch.gathered_buffer = self.gathered_buffer + if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: self.seq_lens_cpu.fill_(self.seq_len_fill_value) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 171a0327e..b69e2939c 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from 
sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 +logger = logging.getLogger(__name__) + if is_cuda(): from sgl_kernel import ( fast_topk, @@ -69,6 +71,8 @@ class EagleDraftInput: kv_indices: torch.Tensor = None def prepare_for_extend(self, batch: ScheduleBatch): + if batch.forward_mode.is_idle(): + return # Prefill only generate 1 token. assert len(self.verified_id) == len(batch.seq_lens) @@ -80,6 +84,24 @@ class EagleDraftInput: ) pt += extend_len + @classmethod + def create_idle_input( + cls, + device: torch.device, + hidden_size: int, + topk: int, + capture_hidden_mode: CaptureHiddenMode, + ): + return cls( + verified_id=None, + hidden_states=torch.empty( + (0, hidden_size), device=device, dtype=torch.float32 + ), + topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), + topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), + capture_hidden_mode=capture_hidden_mode, + ) + def prepare_extend_after_decode( self, batch: ScheduleBatch, @@ -193,7 +215,35 @@ class EagleVerifyInput: seq_lens_cpu: torch.Tensor grammar: BaseGrammarObject = None + @classmethod + def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int): + return cls( + draft_token=torch.empty((0,), dtype=torch.long, device="cuda"), + custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"), + positions=torch.empty((0,), dtype=torch.int64, device="cuda"), + retrive_index=torch.full( + (0, num_verify_tokens), -1, dtype=torch.long, device="cuda" + ), + retrive_next_token=torch.full( + (0, num_verify_tokens), -1, dtype=torch.long, device="cuda" + ), + retrive_next_sibling=torch.full( + (0, num_verify_tokens), -1, dtype=torch.long, device="cuda" + ), + retrive_cum_len=None, + topk=topk, + draft_token_num=num_verify_tokens, + spec_steps=spec_steps, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=0, + seq_lens_cpu=torch.empty((0,), dtype=torch.int32), + ) + def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): + + if batch.forward_mode.is_idle(): + return + batch.input_ids = self.draft_token if page_size == 1: @@ -279,6 +329,25 @@ class EagleVerifyInput: tokens. I.e., logits_output.next_token_logits only contains accepted token logits. 
""" + if batch.forward_mode.is_idle(): + return EagleVerifyOutput( + draft_input=EagleDraftInput.create_idle_input( + device=batch.device, + hidden_size=batch.model_config.hidden_size, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ), + logits_output=logits_output, + verified_id=torch.empty(0, dtype=torch.long, device=batch.device), + accept_length_per_req_cpu=[], + accepted_indices=torch.full( + (0, self.spec_steps + 1), + -1, + dtype=torch.int32, + device=batch.device, + ), + ) + bs = self.retrive_index.shape[0] candidates = self.draft_token.reshape(bs, self.draft_token_num) sampling_info = batch.sampling_info @@ -992,10 +1061,11 @@ def select_top_k_tokens( topk_index = topk_index.reshape(-1, topk**2) input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() - selected_input_index = topk_cs_index.flatten() // topk + torch.arange( - 0, hidden_states.shape[0], step=topk, device="cuda" - ).repeat_interleave(topk) - hidden_states = hidden_states[selected_input_index, :] + if hidden_states.shape[0] > 0: + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device="cuda" + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] tree_info = ( expand_scores, # shape: (b, topk, topk) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index c9157c225..6d0482f46 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -7,8 +7,12 @@ from typing import List, Optional, Tuple import torch from huggingface_hub import snapshot_download -from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group -from sglang.srt.layers.dp_attention import disable_dp_size +from sglang.srt.distributed import ( + GroupCoordinator, + get_tensor_model_parallel_world_size, + get_tp_group, + patch_tensor_parallel_group, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.managers.schedule_batch import ( @@ -57,7 +61,7 @@ logger = logging.getLogger(__name__) def draft_tp_context(tp_group: GroupCoordinator): # Draft model doesn't use dp and has its own tp group. # We disable mscclpp now because it doesn't support 2 comm groups. - with disable_dp_size(), patch_tensor_parallel_group(tp_group): + with patch_tensor_parallel_group(tp_group): yield @@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker): self.server_args = server_args self.topk = server_args.speculative_eagle_topk self.speculative_num_steps = server_args.speculative_num_steps + self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens self.enable_nan_detection = server_args.enable_nan_detection self.gpu_id = gpu_id self.device = server_args.device @@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker): A tuple of the final logit output of the target model, next tokens accepted, the batch id (used for overlap schedule), and number of accepted tokens. 
""" - if batch.forward_mode.is_decode(): - with self.draft_tp_context(self.draft_model_runner.tp_group): - spec_info = self.draft(batch) - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, spec_info) - ) - - # If it is None, it means all requests are finished - if batch.spec_info.verified_id is not None: - with self.draft_tp_context(self.draft_model_runner.tp_group): - self.forward_draft_extend_after_decode(batch) - return ( - logits_output, - verify_output.verified_id, - model_worker_batch.bid, - sum(verify_output.accept_length_per_req_cpu), - can_run_cuda_graph, - ) - elif batch.forward_mode.is_idle(): - model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids, _ = ( - self.target_worker.forward_batch_generation(model_worker_batch) - ) - - return logits_output, next_token_ids, model_worker_batch.bid, 0, False - else: + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: logits_output, next_token_ids, bid, seq_lens_cpu = ( self.forward_target_extend(batch) ) @@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker): batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) return logits_output, next_token_ids, bid, 0, False + else: + with self.draft_tp_context(self.draft_model_runner.tp_group): + spec_info = self.draft(batch) + logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( + self.verify(batch, spec_info) + ) + need_forward, can_run_draft_extend_cuda_graph = ( + self.check_forward_draft_extend_after_decode(batch) + ) + if need_forward: + with self.draft_tp_context(self.draft_model_runner.tp_group): + self.forward_draft_extend_after_decode( + batch, can_run_draft_extend_cuda_graph + ) + return ( + logits_output, + verify_output.verified_id, + model_worker_batch.bid, + sum(verify_output.accept_length_per_req_cpu), + can_run_cuda_graph, + ) + + def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): + local_need_forward = ( + batch.spec_info.verified_id is not None + and batch.spec_info.verified_id.shape[0] > 0 + ) + if not self.server_args.enable_dp_attention: + return local_need_forward, True + + global_need_forward = torch.tensor( + [ + (local_need_forward), + ], + dtype=torch.int64, + ) + torch.distributed.all_reduce( + global_need_forward, group=get_tp_group().cpu_group + ) + global_need_forward_cnt = global_need_forward[0].item() + need_forward = global_need_forward_cnt > 0 + can_run_draft_extend_cuda_graph = ( + global_need_forward_cnt == get_tensor_model_parallel_world_size() + ) + return need_forward, can_run_draft_extend_cuda_graph def forward_target_extend( self, batch: ScheduleBatch @@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker): # We need the full hidden states to prefill the KV cache of the draft model. 
model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + model_worker_batch.spec_num_draft_tokens = 1 logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( model_worker_batch ) @@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker): model_worker_batch.seq_lens_cpu, ) - def draft(self, batch: ScheduleBatch): + def _draft_preprocess_decode(self, batch: ScheduleBatch): # Parse args num_seqs = batch.batch_size() spec_info = batch.spec_info @@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker): batch.seq_lens_sum = torch.sum(batch.seq_lens).item() batch.return_hidden_states = False spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) + self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup) + + def _draft_preprocess_idle(self, batch: ScheduleBatch): + batch.spec_info = EagleDraftInput.create_idle_input( + device=self.device, + hidden_size=self.model_config.hidden_size, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + def draft(self, batch: ScheduleBatch): + # Parse args + if batch.forward_mode.is_idle(): + self._draft_preprocess_idle(batch) + else: + self._draft_preprocess_decode(batch) + + spec_info = batch.spec_info + spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + batch.return_hidden_states = False # Get forward batch model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.spec_num_draft_tokens = self.topk + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker): forward_batch ) else: - # Initialize attention backend - self.draft_attn_backend.init_forward_metadata(forward_batch) + if not forward_batch.forward_mode.is_idle(): + # Initialize attention backend + self.draft_attn_backend.init_forward_metadata(forward_batch) # Run forward steps score_list, token_list, parents_list = self.draft_forward(forward_batch) - self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup) + if batch.forward_mode.is_idle(): + return EagleVerifyInput.create_idle_input( + self.topk, + self.speculative_num_steps, + self.speculative_num_draft_tokens, + ) ( tree_mask, @@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker): batch.seq_lens_sum, self.topk, self.speculative_num_steps, - self.server_args.speculative_num_draft_tokens, + self.speculative_num_draft_tokens, ) return EagleVerifyInput( @@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker): def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): spec_info.prepare_for_verify(batch, self.page_size) batch.return_hidden_states = False - batch.forward_mode = ForwardMode.TARGET_VERIFY + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu ) + model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode if batch.has_grammar: @@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker): self.add_logprob_values(batch, res, logits_output) # Prepare the batch for the next draft forwards. 
- batch.forward_mode = ForwardMode.DECODE + batch.forward_mode = ( + ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE + ) batch.spec_info = res.draft_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker): model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=seq_lens_cpu ) + model_worker_batch.spec_num_draft_tokens = 1 forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker): assert forward_batch.spec_info is batch.spec_info self.capture_for_decode(logits_output, forward_batch.spec_info) - def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + def forward_draft_extend_after_decode( + self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool + ): # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() req_pool_indices_backup = batch.req_pool_indices accept_length_backup = batch.spec_info.accept_length return_logprob_backup = batch.return_logprob - # Prepare metadata - batch.spec_info.prepare_extend_after_decode( - batch, - self.speculative_num_steps, - ) + input_is_idle = batch.forward_mode.is_idle() + if not input_is_idle: + # Prepare metadata + if batch.spec_info.verified_id is not None: + batch.spec_info.prepare_extend_after_decode( + batch, + self.speculative_num_steps, + ) + else: + batch = batch.copy() + batch.prepare_for_idle() + batch.spec_info = EagleDraftInput.create_idle_input( + device=self.device, + hidden_size=self.model_config.hidden_size, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker): # Run can_cuda_graph = ( - self.cuda_graph_runner_for_draft_extend + can_run_draft_extend_cuda_graph + and self.cuda_graph_runner_for_draft_extend and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch) ) if can_cuda_graph: @@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker): ) forward_batch.spec_info.hidden_states = logits_output.hidden_states else: - self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch) + if not forward_batch.forward_mode.is_idle(): + self.draft_model_runner.attn_backend.init_forward_metadata( + forward_batch + ) logits_output = self.draft_model_runner.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) @@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker): # Restore backup. 
# This is because `seq_lens` can be modified in `prepare_extend_after_decode` - batch.forward_mode = ForwardMode.DECODE + batch.forward_mode = ( + ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE + ) batch.seq_lens = seq_lens_backup batch.req_pool_indices = req_pool_indices_backup batch.spec_info.accept_length = accept_length_backup diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index b47fe2c46..af50dc780 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -1,13 +1,19 @@ import unittest from types import SimpleNamespace +import requests + from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_amd_ci, popen_launch_server, ) @@ -65,5 +71,71 @@ class TestDPAttentionDP2TP2(CustomTestCase): self.assertGreater(metrics["score"], 0.8) +class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--disable-radix", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "2", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "4", + "--speculative-draft", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, + "--tp-size", + "2", + "--enable-dp-attention", + "--dp-size", + "2", + ] + if not is_in_amd_ci(): + other_args += ["--mem-frac", "0.7"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.60) + + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print( + f"###test_gsm8k (deepseek-v3 mtp + dp):\n" + f"accuracy={metrics['accuracy']=:.3f}\n" + f"{avg_spec_accept_length=:.3f}\n" + ) + self.assertGreater(avg_spec_accept_length, 2.5) + + if __name__ == "__main__": unittest.main()
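A short illustrative sketch (not part of the diff above, names are hypothetical): the DP-attention CUDA-graph paths in this patch repeat two small computations — the even split of padded tokens across DP ranks used when capturing graphs, and the conversion from gathered token counts back to an effective batch size when EAGLE/MTP contributes multiple draft tokens per request. The standalone functions below only restate that arithmetic for clarity.

# Minimal sketch, assuming the expressions used in the capture/can_run paths of this patch.
# Function names are illustrative and do not exist in the SGLang codebase.

def split_tokens_across_dp_ranks(num_tokens: int, dp_size: int) -> list[int]:
    # Even split of padded tokens across DP ranks, matching
    # num_tokens // dp_size + (i < num_tokens % dp_size) from the graph-capture code.
    return [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)]


def effective_batch_size(
    global_num_tokens_cpu: list[int], num_tokens_per_bs: int, is_eagle: bool
) -> int:
    # With EAGLE/MTP, each request contributes num_tokens_per_bs tokens per decode
    # step, so the CUDA-graph lookup key is total tokens divided by tokens-per-request.
    total = sum(global_num_tokens_cpu)
    return total // num_tokens_per_bs if is_eagle else total


if __name__ == "__main__":
    # 10 padded tokens over 4 DP ranks -> [3, 3, 2, 2]
    print(split_tokens_across_dp_ranks(10, 4))
    # Two DP ranks holding 8 gathered tokens each, 4 draft tokens per request -> batch size 4
    print(effective_batch_size([8, 8], 4, is_eagle=True))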