feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
        if kv_indices_buf is None:

@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
        """Init the metadata for a forward pass."""
        raise NotImplementedError()

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Init the global shared states for cuda graph."""
        raise NotImplementedError()

@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
    def init_cuda_graph_state(
        self,
        max_bs: int,
+        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:

@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):

        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        """Initialize CUDA graph state for the attention backend.

        Args:

@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_forward_metadata(forward_batch)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
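Note on the signature change that repeats across the attention backends above: buffers indexed per request (for example kv_last_page_len) keep being sized by max_bs, while buffers indexed per token (custom masks, attention logits, kv indices) are now sized by the new max_num_tokens argument. With EAGLE/MTP speculative decoding the two differ, because every request carries several draft tokens per forward pass. A rough sketch of the relation, with illustrative numbers and a hypothetical helper name (not taken from this diff):

    # How a CUDA-graph runner might derive the two sizes it passes to
    # init_cuda_graph_state(max_bs, max_num_tokens). num_tokens_per_bs is 1
    # for plain decode and the number of draft tokens per request for EAGLE.
    def cuda_graph_sizes(capture_bs, num_tokens_per_bs):
        max_bs = max(capture_bs)
        max_num_tokens = max_bs * num_tokens_per_bs
        return max_bs, max_num_tokens

    print(cuda_graph_sizes([1, 2, 4, 8], 1))  # (8, 8)   plain decode
    print(cuda_graph_sizes([1, 2, 4, 8], 4))  # (8, 32)  4 draft tokens per request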
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len,),
+                (max_num_tokens * self.max_context_len,),
                dtype=torch.int32,
                device="cuda",
            )

@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device="cuda",
            )

@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,

@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        if kv_indices_buf is None:
            cuda_graph_kv_indices = torch.zeros(

@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.max_context_len),
            dtype=torch.int32,

@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:

        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
    def init_cuda_graph_state(
        self,
        max_bs: int,
+        max_num_tokens: int,
        block_kv_indices: Optional[torch.Tensor] = None,
    ):
        if block_kv_indices is None:

@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:

        self.common_template(forward_batch, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, max_num_tokens, block_kv_indices=None
+            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
        def call_fn(i, forward_batch):

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
            if forward_batch_child.batch_size > 0:
                child.init_forward_metadata(forward_batch=forward_batch_child)

-    def init_cuda_graph_state(self, max_bs: int):
-        self.primary.init_cuda_graph_state(max_bs=max_bs)
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
-            item.init_cuda_graph_state(max_bs=max_bs)
+            item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,

@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
            num_kv_splits = None
            attn_logits = None
            attn_lse = None
+
        elif forward_batch.forward_mode.is_draft_extend():
            kv_indices, kv_indptr, qo_indptr, custom_mask = (
                spec_info.generate_attn_arg_prefill(

@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
        )

    def init_cuda_graph_state(
-        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_attn_lse = torch.zeros(
-            (max_bs, self.num_head, self.max_kv_splits),
+            (max_num_tokens, self.num_head, self.max_kv_splits),
            dtype=torch.float32,
            device=self.device,
        )
        self.cuda_graph_num_kv_splits = torch.full(
-            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
        )
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.int32,
                device=self.device,
            )

@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):

        if not self.skip_prefill:
            self.cuda_graph_custom_mask = torch.zeros(
-                (max_bs * self.max_context_len),
+                (max_num_tokens * self.max_context_len),
                dtype=torch.uint8,
                device=self.device,
            )

@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
        if self.sliding_window_size is not None and self.sliding_window_size > 0:
            if kv_indices_buf is None:
                self.cuda_graph_window_kv_indices = torch.zeros(
-                    (max_bs * self.sliding_window_size),
+                    (max_num_tokens * self.sliding_window_size),
                    dtype=torch.int32,
                    device=self.device,
                )

@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
                self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)

            self.cuda_graph_window_num_kv_splits = torch.full(
-                (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
            )

    def init_forward_metadata_capture_cuda_graph(

@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
            )

            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
            mask_indptr = self.mask_indptr[: bs + 1]
            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)

@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:

        self.common_template(forward_batch, kv_indices, call_fn)

-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        self.cuda_graph_kv_indices = torch.zeros(
-            (self.speculative_num_steps, max_bs * self.max_context_len),
+            (self.speculative_num_steps, max_num_tokens * self.max_context_len),
            dtype=torch.int32,
            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
            )

    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -238,6 +238,10 @@ def _dp_gather(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
    )

@@ -288,6 +292,10 @@ def dp_scatter(
    assert (
        local_tokens.untyped_storage() is not global_tokens.untyped_storage()
    ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
    memcpy_triton(
        local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
    )
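The two additions above guard the DP gather/scatter copies during draft extend: the token count advertised for the local rank can exceed the number of rows actually materialized in local_tokens, so it is clamped to the tensor's first dimension before memcpy_triton runs. A minimal, self-contained sketch of the clamp with made-up sizes:

    import torch

    local_tokens = torch.zeros(6, 16)    # only 6 token rows actually exist
    local_num_tokens = torch.tensor(8)   # metadata still claims 8
    shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
    local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
    print(local_num_tokens.item())       # 6, so the copy stays in bounds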
@@ -862,6 +862,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

@@ -1760,11 +1761,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            decoding_reqs=self.decoding_reqs,
            spec_algorithm=self.spec_algorithm,
            enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
        )

    def __str__(self):
        return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
            f"#req={(len(self.reqs))})"
        )

@@ -1833,6 +1838,7 @@ class ModelWorkerBatch:
    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    # If set, the output of the batch contains the hidden states of the run.
    capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None

    # Overlap event
    launch_done: Optional[threading.Event] = None

@@ -1350,6 +1350,29 @@ class Scheduler(
                self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

+    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
+        """Coordinate the DP attention batch."""
+
+        local_info = torch.tensor(
+            [
+                (new_batch is not None),
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 1),
+            dtype=torch.int64,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_info.flatten(),
+            local_info,
+            group=self.tp_cpu_group,
+        )
+        any_new_batch = any(
+            global_info[:, 0, 0].tolist()
+        )  # Any DP worker has forward batch
+        return any_new_batch
+
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()

@@ -1383,7 +1406,14 @@ class Scheduler(
            self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
+
+        # TODO(ch-wan): minor refactor is needed here to improve readability
+        any_new_batch = (
+            self.server_args.enable_dp_attention
+            and not self.spec_algorithm.is_none()
+            and self.coordinate_spec_dp_attn_batch(new_batch)
+        )
+        if new_batch is not None or any_new_batch:
            # Run prefill first if possible
            ret = new_batch
        else:

@@ -1732,8 +1762,6 @@ class Scheduler(
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens

@@ -1809,6 +1837,7 @@ class Scheduler(
            local_batch.global_num_tokens_for_logprob = (
                global_num_tokens_for_logprob
            )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

@@ -1816,6 +1845,7 @@ class Scheduler(
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
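The new coordinate_spec_dp_attn_batch above is the synchronization point for MTP with DP attention: every rank contributes a single flag (whether it produced a new prefill batch) to an all-gather over the CPU process group, and if any DP worker has one, all workers take the prefill path, with idle batches filling in on ranks that have no work, so the collectives inside the forward pass stay aligned. A single-process sketch of the decision rule (the gathered tensor is faked here rather than produced by torch.distributed):

    import torch

    dp_size, attn_tp_size = 4, 1
    has_new_batch_per_rank = [True, False, False, False]   # hypothetical flags
    global_info = torch.tensor(has_new_batch_per_rank, dtype=torch.int64).view(
        dp_size, attn_tp_size, 1
    )
    any_new_batch = any(global_info[:, 0, 0].tolist())
    print(any_new_batch)   # True: every DP rank runs the prefill path this step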
@@ -242,13 +242,13 @@ class CudaGraphRunner:
        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(

@@ -323,12 +323,15 @@ class CudaGraphRunner:

    def can_run(self, forward_batch: ForwardBatch):
        if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
+                total_batch_size in self.graphs
                if self.disable_padding
-                else total_global_tokens <= self.max_bs
+                else total_batch_size <= self.max_bs
            )
        else:
            is_bs_supported = (

@@ -460,7 +463,7 @@ class CudaGraphRunner:
        self.global_num_tokens_gpu.copy_(
            torch.tensor(
                [
-                    num_tokens // self.dp_size + (i < bs % self.dp_size)
+                    num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                    for i in range(self.dp_size)
                ],
                dtype=torch.int32,

@@ -605,9 +608,12 @@ class CudaGraphRunner:

        # Pad
        if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]

@@ -650,13 +656,13 @@ class CudaGraphRunner:
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
-            self.req_pool_indices,
-            self.seq_lens,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
-            self.encoder_lens,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            forward_batch.forward_mode,
            forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
        )

        # Store fields
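In the CudaGraphRunner changes above, the DP-attention path no longer compares the raw token sum against the captured batch sizes. Under EAGLE, forward_batch.global_num_tokens_cpu already counts draft tokens, so the sum is divided back by num_tokens_per_bs to recover a batch size before looking up a captured graph or picking a padding target. Worked numbers (illustrative only):

    # 2 DP ranks, 3 running requests each, EAGLE with 4 draft tokens per request.
    global_num_tokens_cpu = [12, 12]
    num_tokens_per_bs = 4      # speculative draft tokens per request
    is_eagle = True

    total_batch_size = (
        sum(global_num_tokens_cpu) // num_tokens_per_bs
        if is_eagle
        else sum(global_num_tokens_cpu)
    )
    print(total_batch_size)    # 6, the value compared against capture_bs / max_bs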
@@ -320,17 +320,30 @@ class ForwardBatch:

        # For DP attention
        if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+
+            ret.global_num_tokens_cpu = global_num_tokens
            ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
            ).to(device, non_blocking=True)

-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
            ).to(device, non_blocking=True)

-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
            ret.gathered_buffer = torch.zeros(
                (sum_len, model_runner.model_config.hidden_size),
                dtype=model_runner.dtype,

@@ -163,6 +163,7 @@ class ModelRunner:
        logger.addFilter(RankZeroFilter(tp_rank == 0))
        self.tp_rank = tp_rank
        self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.dist_port = nccl_port

@@ -196,6 +197,7 @@ class ModelRunner:
            | {
                # TODO it is indeed not a "server args"
                "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
            }
        )

@@ -22,7 +22,6 @@ from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (

@@ -77,6 +76,7 @@ class DeepseekModelNextN(nn.Module):
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
+
        zero_allocator = BumpAllocator(
            buffer_size=2,
            dtype=torch.float32,

@@ -90,6 +90,7 @@ class DeepseekModelNextN(nn.Module):
        else:
            hidden_states = input_embeds

+        if hidden_states.shape[0] > 0:
            hidden_states = self.eh_proj(
                torch.cat(
                    (

@@ -127,21 +128,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
        self.model = DeepseekModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
        )

-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

@@ -1399,7 +1399,9 @@ class DeepseekV2DecoderLayer(nn.Module):
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
+        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,

@@ -1500,6 +1502,11 @@ class DeepseekV2DecoderLayer(nn.Module):
            hidden_states, residual, forward_batch
        )

+        if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
+            # NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
+            # See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
+            hidden_states = hidden_states.clone()
+
        return hidden_states, residual

    def op_comm_prepare_attn(
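The hidden_states.clone() added above breaks aliasing between the decoder output handed to the MTP draft pass and buffers that later DP-attention steps overwrite in place; per the linked discussion, the aliasing showed up as a lower MTP acceptance rate on non-zero DP ranks. A generic illustration of the hazard the clone avoids (the buffer names here are hypothetical, not from the model code):

    import torch

    shared_buffer = torch.zeros(8, 4)     # stands in for a reused gathered buffer
    aliased = shared_buffer[:4]           # view kept by a later consumer
    cloned = shared_buffer[:4].clone()    # decoupled copy
    shared_buffer.fill_(1.0)              # next step reuses the buffer in place
    print(aliased[0, 0].item(), cloned[0, 0].item())   # 1.0 vs 0.0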
@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.dp_size = self.model_runner.dp_size
        self.tp_size = self.model_runner.tp_size
        self.topk = model_runner.server_args.speculative_eagle_topk
        self.speculative_num_steps = model_runner.server_args.speculative_num_steps

@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
        self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
            0
        ].get_cuda_graph_seq_len_fill_value()

@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
        self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
        self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
        self.hidden_states = torch.zeros(
-            (self.max_num_token, self.model_runner.model_config.hidden_size),
+            (self.max_bs, self.model_runner.model_config.hidden_size),
            dtype=self.model_runner.dtype,
        )

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+            self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )
+
        # Capture
        try:
            with model_capture_mode():

@@ -92,6 +114,21 @@ class EAGLEDraftCudaGraphRunner:
            )

    def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention:
+            # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
+            if not forward_batch.can_run_dp_cuda_graph:
+                return False
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            is_bs_supported = (
+                total_batch_size in self.graphs
+                if self.disable_padding
+                else total_batch_size <= self.max_bs
+            )
+        else:
            is_bs_supported = (
                forward_batch.batch_size in self.graphs
                if self.disable_padding

@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
        topk_index = self.topk_index[:num_seqs]
        hidden_states = self.hidden_states[:num_seqs]

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
        spec_info = EagleDraftInput(
-            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

        # Forward batch

@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=(
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
        )

        # Attention backend

@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:

        # Run and capture
        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
            # Backup two fields, which will be modified in-place in `draft_forward`.
            output_cache_loc_backup = forward_batch.out_cache_loc
            hidden_states_backup = forward_batch.spec_info.hidden_states

@@ -184,6 +259,14 @@ class EAGLEDraftCudaGraphRunner:
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:

@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
        self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
        self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
        # Attention backend
        if bs != raw_bs:
            forward_batch.batch_size = bs

@@ -210,7 +300,9 @@ class EAGLEDraftCudaGraphRunner:
            forward_batch.req_pool_indices = self.req_pool_indices[:bs]
            forward_batch.positions = self.positions[:num_tokens]

-        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
+        # Special handle for seq_len_cpu used when flashinfer mla is used
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
                self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
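When the EAGLE draft CUDA-graph runner captures a graph, it fabricates a balanced token split across DP ranks; the expression num_tokens // dp_size + (i < (num_tokens % dp_size)) hands the remainder to the lowest-index ranks so the per-rank counts always sum back to num_tokens. Small worked example:

    num_tokens, dp_size = 10, 4
    split = [
        num_tokens // dp_size + (i < (num_tokens % dp_size)) for i in range(dp_size)
    ]
    print(split, sum(split))   # [3, 3, 2, 2] 10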
@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.output_buffers = {}
|
self.output_buffers = {}
|
||||||
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
|
||||||
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
||||||
|
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
|
||||||
|
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
|
||||||
self.tp_size = self.model_runner.tp_size
|
self.tp_size = self.model_runner.tp_size
|
||||||
self.dp_size = model_runner.server_args.dp_size
|
self.dp_size = model_runner.server_args.dp_size
|
||||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||||
@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||||
|
|
||||||
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
|
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
|
||||||
self.max_num_token
|
self.max_bs, self.max_num_token
|
||||||
)
|
)
|
||||||
self.seq_len_fill_value = (
|
self.seq_len_fill_value = (
|
||||||
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
|
self.gathered_buffer = torch.zeros(
|
||||||
|
(
|
||||||
|
self.max_num_token,
|
||||||
|
self.model_runner.model_config.hidden_size,
|
||||||
|
),
|
||||||
|
dtype=self.model_runner.dtype,
|
||||||
|
)
|
||||||
|
self.global_num_tokens_gpu = torch.zeros(
|
||||||
|
(self.dp_size,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
self.global_num_tokens_for_logprob_gpu = torch.zeros(
|
||||||
|
(self.dp_size,), dtype=torch.int32
|
||||||
|
)
|
||||||
|
|
||||||
# Capture
|
# Capture
|
||||||
try:
|
try:
|
||||||
with model_capture_mode():
|
with model_capture_mode():
|
||||||
@@ -100,6 +117,21 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def can_run(self, forward_batch: ForwardBatch):
|
def can_run(self, forward_batch: ForwardBatch):
|
||||||
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
|
if not forward_batch.can_run_dp_cuda_graph:
|
||||||
|
return False
|
||||||
|
total_batch_size = (
|
||||||
|
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||||
|
if self.model_runner.spec_algorithm.is_eagle()
|
||||||
|
else sum(forward_batch.global_num_tokens_cpu)
|
||||||
|
)
|
||||||
|
is_bs_supported = (
|
||||||
|
total_batch_size in self.graphs
|
||||||
|
if self.disable_padding
|
||||||
|
else total_batch_size <= self.max_bs
|
||||||
|
)
|
||||||
|
return is_bs_supported
|
||||||
|
else:
|
||||||
batch_size = forward_batch.seq_lens.numel()
|
batch_size = forward_batch.seq_lens.numel()
|
||||||
|
|
||||||
is_bs_supported = (
|
is_bs_supported = (
|
||||||
@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
positions = self.positions[:num_tokens]
|
positions = self.positions[:num_tokens]
|
||||||
hidden_states = self.hidden_states[:num_tokens]
|
hidden_states = self.hidden_states[:num_tokens]
|
||||||
|
|
||||||
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
|
self.global_num_tokens_gpu.copy_(
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||||
|
for i in range(self.dp_size)
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.input_ids.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||||
|
torch.tensor(
|
||||||
|
[
|
||||||
|
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||||
|
for i in range(self.dp_size)
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.input_ids.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
global_num_tokens = self.global_num_tokens_gpu
|
||||||
|
gathered_buffer = self.gathered_buffer[:num_tokens]
|
||||||
|
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
|
||||||
|
else:
|
||||||
|
global_num_tokens = None
|
||||||
|
gathered_buffer = None
|
||||||
|
global_num_tokens_for_logprob = None
|
||||||
|
|
||||||
spec_info = EagleDraftInput(
|
spec_info = EagleDraftInput(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
accept_length=accept_length,
|
accept_length=accept_length,
|
||||||
@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
seq_lens_sum=seq_lens.sum().item(),
|
seq_lens_sum=seq_lens.sum().item(),
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
|
global_num_tokens_gpu=global_num_tokens,
|
||||||
|
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
|
||||||
|
gathered_buffer=gathered_buffer,
|
||||||
spec_algorithm=self.model_runner.spec_algorithm,
|
spec_algorithm=self.model_runner.spec_algorithm,
|
||||||
spec_info=spec_info,
|
spec_info=spec_info,
|
||||||
capture_hidden_mode=CaptureHiddenMode.LAST,
|
capture_hidden_mode=CaptureHiddenMode.LAST,
|
||||||
@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
|
|
||||||
# Run and capture
|
# Run and capture
|
||||||
def run_once():
|
def run_once():
|
||||||
|
# Clean intermediate result cache for DP attention
|
||||||
|
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
|
||||||
|
|
||||||
# Backup two fields, which will be modified in-place in `draft_forward`.
|
# Backup two fields, which will be modified in-place in `draft_forward`.
|
||||||
output_cache_loc_backup = forward_batch.out_cache_loc
|
output_cache_loc_backup = forward_batch.out_cache_loc
|
||||||
hidden_states_backup = forward_batch.spec_info.hidden_states
|
hidden_states_backup = forward_batch.spec_info.hidden_states
|
||||||
@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
|
|||||||
# in the batch, which will not be counted as num_seqs
|
# in the batch, which will not be counted as num_seqs
|
||||||
raw_bs = forward_batch.batch_size
|
raw_bs = forward_batch.batch_size
|
||||||
num_tokens = forward_batch.input_ids.shape[0]
|
num_tokens = forward_batch.input_ids.shape[0]
|
||||||
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
|
total_batch_size = (
|
||||||
|
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||||
|
if self.model_runner.spec_algorithm.is_eagle()
|
||||||
|
else sum(forward_batch.global_num_tokens_cpu)
|
||||||
|
)
|
||||||
|
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||||
|
else:
|
||||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||||
|
|
||||||
bs = self.capture_bs[index]
|
bs = self.capture_bs[index]
|
||||||
if bs * self.num_tokens_per_bs != num_tokens:
|
if bs * self.num_tokens_per_bs != num_tokens:
|
||||||
self.seq_lens.fill_(self.seq_len_fill_value)
|
self.seq_lens.fill_(self.seq_len_fill_value)
|
||||||
self.out_cache_loc.zero_()
|
self.out_cache_loc.zero_()
|
||||||
self.accept_length.fill_(1)
|
self.accept_length.fill_(1)
|
||||||
|
self.extend_seq_lens.fill_(1)
|
||||||
|
|
||||||
# Common inputs
|
# Common inputs
|
||||||
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
|
||||||
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
||||||
|
if forward_batch.extend_seq_lens is not None:
|
||||||
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
|
||||||
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
|
||||||
self.positions[:num_tokens].copy_(forward_batch.positions)
|
self.positions[:num_tokens].copy_(forward_batch.positions)
|
||||||
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
|
||||||
|
if forward_batch.spec_info.accept_length is not None:
|
||||||
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
|
||||||
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
||||||
|
|
||||||
|
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||||
|
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
|
||||||
|
self.global_num_tokens_for_logprob_gpu.copy_(
|
||||||
|
forward_batch.global_num_tokens_for_logprob_gpu
|
||||||
|
)
|
||||||
|
forward_batch.gathered_buffer = self.gathered_buffer
|
||||||
|
|
||||||
if forward_batch.seq_lens_cpu is not None:
|
if forward_batch.seq_lens_cpu is not None:
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
|||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if is_cuda():
|
if is_cuda():
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
fast_topk,
|
fast_topk,
|
||||||
@@ -69,6 +71,8 @@ class EagleDraftInput:
|
|||||||
kv_indices: torch.Tensor = None
|
kv_indices: torch.Tensor = None
|
||||||
|
|
||||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||||
|
if batch.forward_mode.is_idle():
|
||||||
|
return
|
||||||
# Prefill only generate 1 token.
|
# Prefill only generate 1 token.
|
||||||
assert len(self.verified_id) == len(batch.seq_lens)
|
assert len(self.verified_id) == len(batch.seq_lens)
|
||||||
|
|
||||||
@@ -80,6 +84,24 @@
             )
             pt += extend_len
 
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty(
+                (0, hidden_size), device=device, dtype=torch.float32
+            ),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
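Note (editorial, not part of the patch): with DP attention enabled, a DP rank that has no requests in a given step must still run the same forward path so it can join the group collectives; the create_idle_input helper above builds a spec-info object out of zero-row tensors for exactly that case. A minimal sketch of the idea, with made-up sizes:

    import torch

    hidden_size, topk = 4096, 4
    # Zero-row tensors stand in for "no verified tokens on this rank"; downstream
    # copies, gathers, and matmuls all accept an empty leading dimension.
    hidden_states = torch.empty((0, hidden_size), dtype=torch.float32)
    topk_p = torch.empty((0, topk), dtype=torch.float32)
    print(hidden_states.shape[0], topk_p.numel())  # 0 0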
@@ -193,7 +215,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.draft_token
 
         if page_size == 1:
@@ -279,6 +329,25 @@ class EagleVerifyInput:
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -992,6 +1061,7 @@ def select_top_k_tokens(
         topk_index = topk_index.reshape(-1, topk**2)
         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
 
-        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-            0, hidden_states.shape[0], step=topk, device="cuda"
-        ).repeat_interleave(topk)
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
-from sglang.srt.layers.dp_attention import disable_dp_size
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker):
             A tuple of the final logit output of the target model, next tokens accepted,
             the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.is_decode():
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
-                spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
-                self.verify(batch, spec_info)
-            )
-
-            # If it is None, it means all requests are finished
-            if batch.spec_info.verified_id is not None:
-                with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
-            return (
-                logits_output,
-                verify_output.verified_id,
-                model_worker_batch.bid,
-                sum(verify_output.accept_length_per_req_cpu),
-                can_run_cuda_graph,
-            )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
-
-            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
-        else:
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
             logits_output, next_token_ids, bid, seq_lens_cpu = (
                 self.forward_target_extend(batch)
             )
@@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker):
                 batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
             )
             return logits_output, next_token_ids, bid, 0, False
+        else:
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                spec_info = self.draft(batch)
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
+            )
+            need_forward, can_run_draft_extend_cuda_graph = (
+                self.check_forward_draft_extend_after_decode(batch)
+            )
+            if need_forward:
+                with self.draft_tp_context(self.draft_model_runner.tp_group):
+                    self.forward_draft_extend_after_decode(
+                        batch, can_run_draft_extend_cuda_graph
+                    )
+            return (
+                logits_output,
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
+            )
+
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward, True
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        can_run_draft_extend_cuda_graph = (
+            global_need_forward_cnt == get_tensor_model_parallel_world_size()
+        )
+        return need_forward, can_run_draft_extend_cuda_graph
 
     def forward_target_extend(
         self, batch: ScheduleBatch
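Note (editorial, not part of the patch): check_forward_draft_extend_after_decode reaches a group-wide decision with a single all_reduce over the TP CPU group. Every rank contributes 0 or 1; the SUM tells each rank whether any peer still needs the draft-extend pass (so all ranks must enter it together, some of them idle), and a count equal to the world size means every rank has real work, which is the only case where the draft-extend CUDA graph can be replayed. A standalone sketch of that decision, assuming torch.distributed has already been initialized:

    import torch
    import torch.distributed as dist

    def draft_extend_decision(local_need_forward: bool, group=None):
        # Each rank contributes 0 or 1; all_reduce defaults to SUM.
        flag = torch.tensor([int(local_need_forward)], dtype=torch.int64)
        dist.all_reduce(flag, group=group)
        cnt = int(flag.item())
        need_forward = cnt > 0                                   # someone has work: everyone joins
        can_use_cuda_graph = cnt == dist.get_world_size(group)   # graph only when nobody is idle
        return need_forward, can_use_cuda_graph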
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch.seq_lens_cpu,
         )
 
-    def draft(self, batch: ScheduleBatch):
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
+
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
-            # Initialize attention backend
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.server_args.speculative_num_draft_tokens,
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
             self.add_logprob_values(batch, res, logits_output)
 
         # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
         batch.spec_info = res.draft_input
 
         return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
-    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+    def forward_draft_extend_after_decode(
+        self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
+    ):
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-        # Prepare metadata
-        batch.spec_info.prepare_extend_after_decode(
-            batch,
-            self.speculative_num_steps,
-        )
+
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=self.model_config.hidden_size,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
 
         # Run
         can_cuda_graph = (
-            self.cuda_graph_runner_for_draft_extend
+            can_run_draft_extend_cuda_graph
+            and self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
         )
         if can_cuda_graph:
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
             logits_output = self.draft_model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
@@ -1,13 +1,19 @@
 import unittest
 from types import SimpleNamespace
 
+import requests
+
 from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MLA_MODEL_NAME_FOR_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_amd_ci,
     popen_launch_server,
 )
 
@@ -65,5 +71,71 @@ class TestDPAttentionDP2TP2(CustomTestCase):
         self.assertGreater(metrics["score"], 0.8)
 
 
+class TestDPAttentionDP2TP2DeepseekV3MTP(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = [
+            "--trust-remote-code",
+            "--disable-radix",
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-num-steps",
+            "2",
+            "--speculative-eagle-topk",
+            "4",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--speculative-draft",
+            DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
+            "--tp-size",
+            "2",
+            "--enable-dp-attention",
+            "--dp-size",
+            "2",
+        ]
+        if not is_in_amd_ci():
+            other_args += ["--mem-frac", "0.7"]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(
+            f"###test_gsm8k (deepseek-v3 mtp + dp):\n"
+            f"accuracy={metrics['accuracy']=:.3f}\n"
+            f"{avg_spec_accept_length=:.3f}\n"
+        )
+        self.assertGreater(avg_spec_accept_length, 2.5)
+
+
 if __name__ == "__main__":
     unittest.main()