feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
Author: u4lr451
Date: 2025-06-17 15:33:28 +08:00
Committed by: GitHub
Parent: 8a10c4c3d9
Commit: 10d60cd41b

22 changed files with 641 additions and 151 deletions

View File

@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )

View File

@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()
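The base-class change above is the contract the rest of the commit implements: every attention backend now receives both a per-request capacity (max_bs) and a per-token capacity (max_num_tokens) for its CUDA-graph buffers, because with speculative decoding one captured request slot carries several draft tokens. A minimal sketch of the relation, with assumed numbers (num_tokens_per_bs is not part of this interface; it stands in for the draft-token multiplier used by the callers):

```python
import torch

# Assumed numbers, for illustration only.
max_bs = 160              # largest captured batch size (requests)
num_tokens_per_bs = 4     # e.g. speculative_num_draft_tokens
max_num_tokens = max_bs * num_tokens_per_bs

# Request-indexed state keeps using max_bs ...
kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
# ... while token-indexed state must be sized by max_num_tokens,
# which is what the per-backend hunks below change.
custom_mask = torch.zeros(max_num_tokens * 2048, dtype=torch.uint8)
```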

View File

@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:

View File

@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.

         Args:
@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,

View File

@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

View File

@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

View File

@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
         def call_fn(i, forward_batch):

View File

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
             if forward_batch_child.batch_size > 0:
                 child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
         for item in self.children:
             # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

     def init_forward_metadata_capture_cuda_graph(
         self,

View File

@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+        elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
             dtype=torch.float32,
             device=self.device,
         )
         self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.int32,
                 device=self.device,
             )
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device=self.device,
             )
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
         if self.sliding_window_size is not None and self.sliding_window_size > 0:
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                     dtype=torch.int32,
                     device=self.device,
                 )
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                 self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
             self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
             )

     def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
         self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
             dtype=torch.int32,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )

     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

View File

@@ -238,6 +238,10 @@ def _dp_gather(
         assert (
             local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
+        if forward_batch.forward_mode.is_draft_extend():
+            shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+            local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
         )
@@ -288,6 +292,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
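In a draft-extend pass, the token count a rank announced for the gather/scatter can exceed the rows actually present in local_tokens (padding, idle ranks), so the copy length is clamped first. Doing the clamp with torch.minimum against a 0-d tensor keeps everything on-device, with no .item() host sync, which matters when this runs under CUDA-graph capture. A standalone sketch of the same pattern:

```python
import torch

def clamped_copy_len(
    local_num_tokens: torch.Tensor, local_tokens: torch.Tensor
) -> torch.Tensor:
    # 0-d tensor holding the real row count, matching dtype/device of local_num_tokens.
    shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
    # Elementwise min; stays on-device, so it is CUDA-graph capturable.
    return torch.minimum(local_num_tokens, shape_tensor)
```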

View File

@@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
@@ -1760,11 +1761,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )

     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )
@@ -1833,6 +1838,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None

     # Overlap event
     launch_done: Optional[threading.Event] = None

View File

@@ -1350,6 +1350,29 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

+    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
+        """Coordinate the DP attention batch."""
+        local_info = torch.tensor(
+            [
+                (new_batch is not None),
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 1),
+            dtype=torch.int64,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_info.flatten(),
+            local_info,
+            group=self.tp_cpu_group,
+        )
+        any_new_batch = any(
+            global_info[:, 0, 0].tolist()
+        )  # Any DP worker has forward batch
+        return any_new_batch

     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1383,7 +1406,14 @@ class Scheduler(
                 self.running_batch.merge_batch(self.last_batch)

         new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
+        # TODO(ch-wan): minor refactor is needed here to improve readability
+        any_new_batch = (
+            self.server_args.enable_dp_attention
+            and not self.spec_algorithm.is_none()
+            and self.coordinate_spec_dp_attn_batch(new_batch)
+        )
+        if new_batch is not None or any_new_batch:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1732,8 +1762,6 @@ class Scheduler(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1809,6 +1837,7 @@ class Scheduler(
                 local_batch.global_num_tokens_for_logprob = (
                     global_num_tokens_for_logprob
                 )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode
@@ -1816,6 +1845,7 @@ class Scheduler(
             if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
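coordinate_spec_dp_attn_batch exists because, with DP attention plus speculative decoding, all DP ranks must enter the same kind of forward pass: if any rank has a new prefill batch, the others must run too (with idle batches) so the collectives line up. The method all-gathers a 0/1 flag over the CPU group and ORs the per-DP-rank results. A self-contained sketch of the same coordination pattern (group setup and world layout are assumptions):

```python
import torch
import torch.distributed as dist

def any_rank_has_new_batch(has_batch: bool, cpu_group) -> bool:
    # Each rank contributes a 0/1 flag; after the all-gather every rank
    # sees all flags and takes a logical OR, so all ranks agree.
    local = torch.tensor([int(has_batch)], dtype=torch.int64)
    out = torch.empty(dist.get_world_size(group=cpu_group), dtype=torch.int64)
    dist.all_gather_into_tensor(out, local, group=cpu_group)
    return bool(out.any())
```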

View File

@@ -242,13 +242,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
@@ -323,12 +323,15 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
+                total_batch_size in self.graphs
                 if self.disable_padding
-                else total_global_tokens <= self.max_bs
+                else total_batch_size <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -460,7 +463,7 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
-                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                         for i in range(self.dp_size)
                     ],
                     dtype=torch.int32,
@@ -605,9 +608,12 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
-            )
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -650,13 +656,13 @@ class CudaGraphRunner:
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
-            self.encoder_lens,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             forward_batch.forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )

         # Store fields
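With EAGLE active, global_num_tokens_cpu now counts draft tokens, so the runner divides by num_tokens_per_bs to recover a batch size before picking a captured graph; padding then happens in batch-size space. A worked example with assumed numbers:

```python
import bisect

capture_bs = [1, 2, 4, 8, 16]     # captured graph batch sizes (assumed)
num_tokens_per_bs = 4             # draft tokens per request (assumed)
global_num_tokens_cpu = [9, 11]   # per-DP-rank token counts, sum = 20

total_batch_size = sum(global_num_tokens_cpu) // num_tokens_per_bs  # 5 requests
index = bisect.bisect_left(capture_bs, total_batch_size)            # -> 3
bs = capture_bs[index]            # 8: replay the bs=8 graph, padding 5 -> 8
```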

View File

@@ -320,17 +320,30 @@ class ForwardBatch:
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
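ScheduleBatch still tracks per-rank request counts; the scaling by spec_num_draft_tokens happens once here, when the ForwardBatch is built, so the gathered buffer is sized for every draft token. A worked example under assumed values:

```python
# Assumed: 2 DP ranks, 4 draft tokens per request during verification.
global_num_tokens = [3, 5]        # requests per DP rank
spec_num_draft_tokens = 4         # defaults to 1 outside speculative passes
scaled = [x * spec_num_draft_tokens for x in global_num_tokens]  # [12, 20]
sum_len = sum(scaled)             # 32 rows for the gathered hidden states
```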

View File

@@ -163,6 +163,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -196,6 +197,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )

View File

@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -77,6 +76,7 @@ class DeepseekModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,6 +90,7 @@ class DeepseekModelNextN(nn.Module):
         else:
             hidden_states = input_embeds

+        if hidden_states.shape[0] > 0:
             hidden_states = self.eh_proj(
                 torch.cat(
                     (
@@ -127,21 +128,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)

View File

@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
         self.layer_id = layer_id
+        self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
             config=config,
             hidden_size=self.hidden_size,
@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
         return hidden_states, residual

     def op_comm_prepare_attn(

View File

@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_num_token, self.model_runner.model_config.hidden_size),
+            (self.max_bs, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+            self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )

         # Capture
         try:
             with model_capture_mode():
@@ -92,6 +114,21 @@ class EAGLEDraftCudaGraphRunner:
             )

     def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention:
+            # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
+            if not forward_batch.can_run_dp_cuda_graph:
+                return False
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            is_bs_supported = (
+                total_batch_size in self.graphs
+                if self.disable_padding
+                else total_batch_size <= self.max_bs
+            )
+        else:
             is_bs_supported = (
                 forward_batch.batch_size in self.graphs
                 if self.disable_padding
@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None

         spec_info = EagleDraftInput(
-            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )

         # Forward batch
@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )

         # Attention backend
@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,6 +259,14 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer

         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -210,7 +300,9 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.req_pool_indices = self.req_pool_indices[:bs]
             forward_batch.positions = self.positions[:num_tokens]

-        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
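When capturing DP-attention graphs, the padded token budget is split across DP ranks with the expression num_tokens // dp_size + (i < num_tokens % dp_size): earlier ranks absorb the remainder one token each, and the per-rank counts always sum back to num_tokens. A small sketch:

```python
def split_tokens(num_tokens: int, dp_size: int) -> list[int]:
    # Earlier ranks (i < remainder) get one extra token; bools add as 0/1.
    return [num_tokens // dp_size + (i < num_tokens % dp_size) for i in range(dp_size)]

assert split_tokens(10, 4) == [3, 3, 2, 2]
assert sum(split_tokens(10, 4)) == 10
```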

View File

@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
         self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
-            self.max_num_token
+            self.max_bs, self.max_num_token
         )
         self.seq_len_fill_value = (
             self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
             (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+            self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )

         # Capture
         try:
             with model_capture_mode():
@@ -100,6 +117,21 @@ class EAGLEDraftExtendCudaGraphRunner:
             )

     def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            if not forward_batch.can_run_dp_cuda_graph:
+                return False
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            is_bs_supported = (
+                total_batch_size in self.graphs
+                if self.disable_padding
+                else total_batch_size <= self.max_bs
+            )
+            return is_bs_supported
+        else:
             batch_size = forward_batch.seq_lens.numel()
             is_bs_supported = (
@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None

         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             accept_length=accept_length,
@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
             self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
             self.accept_length.fill_(1)
+            self.extend_seq_lens.fill_(1)

         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        if forward_batch.extend_seq_lens is not None:
             self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
         self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer

         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)

View File

@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

+logger = logging.getLogger(__name__)
+
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
@@ -69,6 +71,8 @@ class EagleDraftInput:
     kv_indices: torch.Tensor = None

     def prepare_for_extend(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_idle():
+            return
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
@@ -80,6 +84,24 @@ class EagleDraftInput:
             )
             pt += extend_len

+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty(
+                (0, hidden_size), device=device, dtype=torch.float32
+            ),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
@@ -193,7 +215,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None

+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+        if batch.forward_mode.is_idle():
+            return
         batch.input_ids = self.draft_token

         if page_size == 1:
@@ -279,6 +329,25 @@ class EagleVerifyInput:
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -992,6 +1061,7 @@ def select_top_k_tokens(
         topk_index = topk_index.reshape(-1, topk**2)
         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

+        if hidden_states.shape[0] > 0:
             selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                 0, hidden_states.shape[0], step=topk, device="cuda"
             ).repeat_interleave(topk)
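The idle inputs built by create_idle_input work because zero-row tensors flow through the model unchanged: matmuls, concatenations, and gathers all produce zero-row outputs, so an idle DP rank can execute the exact code path (and collectives) of a busy rank without special-casing. A quick illustration with assumed shapes:

```python
import torch

hidden = torch.empty((0, 4096))    # zero-row hidden states, as in create_idle_input
w = torch.randn(4096, 4096)
out = hidden @ w                   # still (0, 4096); no branch needed
topk_p = torch.empty((0, 8))       # zero-row top-k probabilities
print(out.shape, topk_p.shape)     # torch.Size([0, 4096]) torch.Size([0, 8])
```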

View File

@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group from sglang.srt.distributed import (
from sglang.srt.layers.dp_attention import disable_dp_size GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
def draft_tp_context(tp_group: GroupCoordinator): def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group. # Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups. # We disable mscclpp now because it doesn't support 2 comm groups.
with disable_dp_size(), patch_tensor_parallel_group(tp_group): with patch_tensor_parallel_group(tp_group):
yield yield
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
self.server_args = server_args self.server_args = server_args
self.topk = server_args.speculative_eagle_topk self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.enable_nan_detection = server_args.enable_nan_detection self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.device = server_args.device self.device = server_args.device
@@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted, A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens. the batch id (used for overlap schedule), and number of accepted tokens.
""" """
if batch.forward_mode.is_decode(): if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(model_worker_batch)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
logits_output, next_token_ids, bid, seq_lens_cpu = ( logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch) self.forward_target_extend(batch)
) )
@@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker):
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
) )
return logits_output, next_token_ids, bid, 0, False return logits_output, next_token_ids, bid, 0, False
else:
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
need_forward, can_run_draft_extend_cuda_graph = (
self.check_forward_draft_extend_after_decode(batch)
)
if need_forward:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch, can_run_draft_extend_cuda_graph
)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
local_need_forward = (
batch.spec_info.verified_id is not None
and batch.spec_info.verified_id.shape[0] > 0
)
if not self.server_args.enable_dp_attention:
return local_need_forward, True
global_need_forward = torch.tensor(
[
(local_need_forward),
],
dtype=torch.int64,
)
torch.distributed.all_reduce(
global_need_forward, group=get_tp_group().cpu_group
)
global_need_forward_cnt = global_need_forward[0].item()
need_forward = global_need_forward_cnt > 0
can_run_draft_extend_cuda_graph = (
global_need_forward_cnt == get_tensor_model_parallel_world_size()
)
return need_forward, can_run_draft_extend_cuda_graph
def forward_target_extend( def forward_target_extend(
self, batch: ScheduleBatch self, batch: ScheduleBatch
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model. # We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch model_worker_batch
) )
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch.seq_lens_cpu, model_worker_batch.seq_lens_cpu,
) )
def draft(self, batch: ScheduleBatch): def _draft_preprocess_decode(self, batch: ScheduleBatch):
# Parse args # Parse args
num_seqs = batch.batch_size() num_seqs = batch.batch_size()
spec_info = batch.spec_info spec_info = batch.spec_info
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum = torch.sum(batch.seq_lens).item() batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.return_hidden_states = False batch.return_hidden_states = False
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
def _draft_preprocess_idle(self, batch: ScheduleBatch):
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
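
create_idle_input gives an empty rank something shape-consistent to run with. A sketch of what such a factory can look like, assuming (this is not sglang's actual EagleDraftInput definition; field names are illustrative) that zero-row tensors with the right trailing dimensions are all downstream shape logic needs:

from dataclasses import dataclass

import torch

@dataclass
class IdleDraftInput:
    hidden_states: torch.Tensor  # shape [0, hidden_size]
    topk_index: torch.Tensor     # shape [0, topk]
    verified_id: torch.Tensor    # shape [0]

def create_idle_input(device: str, hidden_size: int, topk: int) -> IdleDraftInput:
    # Zero-row tensors keep every downstream shape computation valid
    # while producing no tokens.
    return IdleDraftInput(
        hidden_states=torch.empty(0, hidden_size, device=device),
        topk_index=torch.empty(0, topk, dtype=torch.long, device=device),
        verified_id=torch.empty(0, dtype=torch.long, device=device),
    )

print(create_idle_input("cpu", hidden_size=16, topk=4).hidden_states.shape)
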
def draft(self, batch: ScheduleBatch):
# Parse args
if batch.forward_mode.is_idle():
self._draft_preprocess_idle(batch)
else:
self._draft_preprocess_decode(batch)
spec_info = batch.spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
forward_batch
)
else:
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
if batch.forward_mode.is_idle():
return EagleVerifyInput.create_idle_input(
self.topk,
self.speculative_num_steps,
self.speculative_num_draft_tokens,
)
(
tree_mask,
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
)
return EagleVerifyInput(
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
self.add_logprob_values(batch, res, logits_output)
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
)
batch.spec_info = res.draft_input
return logits_output, res, model_worker_batch, can_run_cuda_graph
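
Both verify and the post-verify cleanup above use the same idle-preserving transition: a rank that entered the round IDLE must stay IDLE, while real batches flip between DECODE and TARGET_VERIFY. A self-contained sketch of the idiom, assuming a simplified ForwardMode enum (the real one lives elsewhere in sglang):

from enum import Enum, auto

class ForwardMode(Enum):
    DECODE = auto()
    TARGET_VERIFY = auto()
    IDLE = auto()

    def is_idle(self) -> bool:
        return self is ForwardMode.IDLE

def to_verify(mode: ForwardMode) -> ForwardMode:
    # IDLE is sticky: an empty rank never enters TARGET_VERIFY.
    return ForwardMode.TARGET_VERIFY if not mode.is_idle() else ForwardMode.IDLE

print(to_verify(ForwardMode.DECODE))  # ForwardMode.TARGET_VERIFY
print(to_verify(ForwardMode.IDLE))    # ForwardMode.IDLE
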
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
def forward_draft_extend_after_decode(
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
):
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
# Run
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
can_run_draft_extend_cuda_graph
and self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
if can_cuda_graph:
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
)
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup
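
The backup/restore bracket around forward_draft_extend_after_decode exists because the pass mutates scheduler-owned fields in place. A minimal sketch of the idiom with a toy batch object; note that the tensor needs .clone() while plain attributes only need rebinding:

import torch

class ToyBatch:
    def __init__(self):
        self.seq_lens = torch.tensor([5, 7])
        self.return_logprob = True

def run_with_restore(batch: ToyBatch):
    seq_lens_backup = batch.seq_lens.clone()      # tensor: clone, since += mutates in place
    return_logprob_backup = batch.return_logprob  # plain attribute: rebinding suffices
    try:
        batch.seq_lens += 3           # stand-in for prepare_extend_after_decode
        batch.return_logprob = False
        # ... the draft-extend forward pass would run here ...
    finally:
        batch.seq_lens = seq_lens_backup
        batch.return_logprob = return_logprob_backup

b = ToyBatch()
run_with_restore(b)
print(b.seq_lens.tolist(), b.return_logprob)  # [5, 7] True
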

View File

@@ -1,13 +1,19 @@
import unittest
from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_amd_ci,
popen_launch_server,
)
@@ -65,5 +71,71 @@ class TestDPAttentionDP2TP2(CustomTestCase):
self.assertGreater(metrics["score"], 0.8)
class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--trust-remote-code",
"--disable-radix",
"--speculative-algorithm",
"EAGLE",
"--speculative-num-steps",
"2",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"4",
"--speculative-draft",
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
"--tp-size",
"2",
"--enable-dp-attention",
"--dp-size",
"2",
]
if not is_in_amd_ci():
other_args += ["--mem-frac", "0.7"]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["internal_states"][0][
"avg_spec_accept_length"
]
print(
f"###test_gsm8k (deepseek-v3 mtp + dp):\n"
f"accuracy={metrics['accuracy']=:.3f}\n"
f"{avg_spec_accept_length=:.3f}\n"
)
self.assertGreater(avg_spec_accept_length, 2.5)
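
The final assertion bounds avg_spec_accept_length, i.e. tokens emitted per target-model forward. A rough sanity check of what 2.5 buys, ignoring draft-model and verification overhead (so an upper bound on the real speedup):

avg_spec_accept_length = 2.5  # the test's lower bound
# Relative to one-token-per-forward decoding, target forwards drop by:
saved = 1 - 1 / avg_spec_accept_length
print(f"{saved:.0%} fewer target forwards")  # 60%
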
if __name__ == "__main__":
unittest.main()