feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-17 15:33:28 +08:00
committed by GitHub
parent 8a10c4c3d9
commit 10d60cd41b
22 changed files with 641 additions and 151 deletions

View File

@@ -38,6 +38,10 @@ class EAGLEDraftCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.dp_size = self.model_runner.dp_size
self.tp_size = self.model_runner.tp_size
self.topk = model_runner.server_args.speculative_eagle_topk
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +57,9 @@ class EAGLEDraftCudaGraphRunner:
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
self.model_runner.draft_attn_backend.init_cuda_graph_state(
self.max_bs, self.max_num_token
)
self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
0
].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +84,26 @@ class EAGLEDraftCudaGraphRunner:
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
(self.max_num_token, self.model_runner.model_config.hidden_size),
(self.max_bs, self.model_runner.model_config.hidden_size),
dtype=self.model_runner.dtype,
)
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
@@ -92,11 +114,26 @@ class EAGLEDraftCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
if self.enable_dp_attention:
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
else:
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
return is_bs_supported
def capture(self):
@@ -116,8 +153,40 @@ class EAGLEDraftCudaGraphRunner:
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
topk_p=topk_p,
topk_index=topk_index,
hidden_states=hidden_states,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
# Forward batch
@@ -133,11 +202,14 @@ class EAGLEDraftCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=(
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
),
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
)
# Attention backend
@@ -147,6 +219,9 @@ class EAGLEDraftCudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,7 +259,15 @@ class EAGLEDraftCudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
if self.enable_dp_attention or self.enable_sp_layernorm:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value)
@@ -203,6 +286,13 @@ class EAGLEDraftCudaGraphRunner:
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
# Attention backend
if bs != raw_bs:
forward_batch.batch_size = bs
@@ -210,8 +300,10 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.req_pool_indices = self.req_pool_indices[:bs]
forward_batch.positions = self.positions[:num_tokens]
if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
# Special handle for seq_len_cpu used when flashinfer mla is used
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

View File

@@ -35,6 +35,8 @@ class EAGLEDraftExtendCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +53,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.max_num_token = self.max_bs * self.num_tokens_per_bs
self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
self.max_num_token
self.max_bs, self.max_num_token
)
self.seq_len_fill_value = (
self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +92,21 @@ class EAGLEDraftExtendCudaGraphRunner:
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
@@ -100,15 +117,30 @@ class EAGLEDraftExtendCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
batch_size = forward_batch.seq_lens.numel()
if self.enable_dp_attention or self.enable_sp_layernorm:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
return is_bs_supported
else:
batch_size = forward_batch.seq_lens.numel()
is_bs_supported = (
batch_size in self.graphs
if self.disable_padding
else batch_size <= self.max_bs
)
is_bs_supported = (
batch_size in self.graphs
if self.disable_padding
else batch_size <= self.max_bs
)
return is_bs_supported
return is_bs_supported
def capture(self):
CudaGraphRunner.capture(self)
@@ -128,6 +160,35 @@ class EAGLEDraftExtendCudaGraphRunner:
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
for i in range(self.dp_size)
],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
global_num_tokens_for_logprob = None
spec_info = EagleDraftInput(
hidden_states=hidden_states,
accept_length=accept_length,
@@ -147,6 +208,9 @@ class EAGLEDraftExtendCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
global_num_tokens_gpu=global_num_tokens,
global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
gathered_buffer=gathered_buffer,
spec_algorithm=self.model_runner.spec_algorithm,
spec_info=spec_info,
capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +231,9 @@ class EAGLEDraftExtendCudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
# Backup two fields, which will be modified in-place in `draft_forward`.
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,24 +270,42 @@ class EAGLEDraftExtendCudaGraphRunner:
# in the batch, which will not be counted as num_seqs
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
if self.enable_dp_attention or self.enable_sp_layernorm:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
index = bisect.bisect_left(self.capture_bs, total_batch_size)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs * self.num_tokens_per_bs != num_tokens:
self.seq_lens.fill_(self.seq_len_fill_value)
self.out_cache_loc.zero_()
self.accept_length.fill_(1)
self.extend_seq_lens.fill_(1)
# Common inputs
self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
if forward_batch.extend_seq_lens is not None:
self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
self.positions[:num_tokens].copy_(forward_batch.positions)
self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
if forward_batch.spec_info.accept_length is not None:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
if self.enable_dp_attention or self.enable_sp_layernorm:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
)
forward_batch.gathered_buffer = self.gathered_buffer
if forward_batch.seq_lens_cpu is not None:
if bs != raw_bs:
self.seq_lens_cpu.fill_(self.seq_len_fill_value)

View File

@@ -25,6 +25,8 @@ from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
logger = logging.getLogger(__name__)
if is_cuda():
from sgl_kernel import (
fast_topk,
@@ -69,6 +71,8 @@ class EagleDraftInput:
kv_indices: torch.Tensor = None
def prepare_for_extend(self, batch: ScheduleBatch):
if batch.forward_mode.is_idle():
return
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
@@ -80,6 +84,24 @@ class EagleDraftInput:
)
pt += extend_len
@classmethod
def create_idle_input(
cls,
device: torch.device,
hidden_size: int,
topk: int,
capture_hidden_mode: CaptureHiddenMode,
):
return cls(
verified_id=None,
hidden_states=torch.empty(
(0, hidden_size), device=device, dtype=torch.float32
),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
)
def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
@@ -193,7 +215,35 @@ class EagleVerifyInput:
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None
@classmethod
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
return cls(
draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
retrive_index=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_next_token=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_next_sibling=torch.full(
(0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
),
retrive_cum_len=None,
topk=topk,
draft_token_num=num_verify_tokens,
spec_steps=spec_steps,
capture_hidden_mode=CaptureHiddenMode.FULL,
seq_lens_sum=0,
seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
)
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
if batch.forward_mode.is_idle():
return
batch.input_ids = self.draft_token
if page_size == 1:
@@ -279,6 +329,25 @@ class EagleVerifyInput:
tokens. I.e., logits_output.next_token_logits only contains
accepted token logits.
"""
if batch.forward_mode.is_idle():
return EagleVerifyOutput(
draft_input=EagleDraftInput.create_idle_input(
device=batch.device,
hidden_size=batch.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
),
logits_output=logits_output,
verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
accept_length_per_req_cpu=[],
accepted_indices=torch.full(
(0, self.spec_steps + 1),
-1,
dtype=torch.int32,
device=batch.device,
),
)
bs = self.retrive_index.shape[0]
candidates = self.draft_token.reshape(bs, self.draft_token_num)
sampling_info = batch.sampling_info
@@ -992,10 +1061,11 @@ def select_top_k_tokens(
topk_index = topk_index.reshape(-1, topk**2)
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
if hidden_states.shape[0] > 0:
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
0, hidden_states.shape[0], step=topk, device="cuda"
).repeat_interleave(topk)
hidden_states = hidden_states[selected_input_index, :]
tree_info = (
expand_scores, # shape: (b, topk, topk)

View File

@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
import torch
from huggingface_hub import snapshot_download
from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
from sglang.srt.layers.dp_attention import disable_dp_size
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_world_size,
get_tp_group,
patch_tensor_parallel_group,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
def draft_tp_context(tp_group: GroupCoordinator):
# Draft model doesn't use dp and has its own tp group.
# We disable mscclpp now because it doesn't support 2 comm groups.
with disable_dp_size(), patch_tensor_parallel_group(tp_group):
with patch_tensor_parallel_group(tp_group):
yield
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
self.server_args = server_args
self.topk = server_args.speculative_eagle_topk
self.speculative_num_steps = server_args.speculative_num_steps
self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
self.enable_nan_detection = server_args.enable_nan_detection
self.gpu_id = gpu_id
self.device = server_args.device
@@ -302,32 +307,7 @@ class EAGLEWorker(TpModelWorker):
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
if batch.forward_mode.is_decode():
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
# If it is None, it means all requests are finished
if batch.spec_info.verified_id is not None:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids, _ = (
self.target_worker.forward_batch_generation(model_worker_batch)
)
return logits_output, next_token_ids, model_worker_batch.bid, 0, False
else:
if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
logits_output, next_token_ids, bid, seq_lens_cpu = (
self.forward_target_extend(batch)
)
@@ -336,6 +316,51 @@ class EAGLEWorker(TpModelWorker):
batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
)
return logits_output, next_token_ids, bid, 0, False
else:
with self.draft_tp_context(self.draft_model_runner.tp_group):
spec_info = self.draft(batch)
logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
self.verify(batch, spec_info)
)
need_forward, can_run_draft_extend_cuda_graph = (
self.check_forward_draft_extend_after_decode(batch)
)
if need_forward:
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch, can_run_draft_extend_cuda_graph
)
return (
logits_output,
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
can_run_cuda_graph,
)
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
local_need_forward = (
batch.spec_info.verified_id is not None
and batch.spec_info.verified_id.shape[0] > 0
)
if not self.server_args.enable_dp_attention:
return local_need_forward, True
global_need_forward = torch.tensor(
[
(local_need_forward),
],
dtype=torch.int64,
)
torch.distributed.all_reduce(
global_need_forward, group=get_tp_group().cpu_group
)
global_need_forward_cnt = global_need_forward[0].item()
need_forward = global_need_forward_cnt > 0
can_run_draft_extend_cuda_graph = (
global_need_forward_cnt == get_tensor_model_parallel_world_size()
)
return need_forward, can_run_draft_extend_cuda_graph
def forward_target_extend(
self, batch: ScheduleBatch
@@ -354,6 +379,7 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
@@ -364,7 +390,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch.seq_lens_cpu,
)
def draft(self, batch: ScheduleBatch):
def _draft_preprocess_decode(self, batch: ScheduleBatch):
# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
@@ -466,10 +492,32 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
batch.return_hidden_states = False
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
def _draft_preprocess_idle(self, batch: ScheduleBatch):
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
def draft(self, batch: ScheduleBatch):
# Parse args
if batch.forward_mode.is_idle():
self._draft_preprocess_idle(batch)
else:
self._draft_preprocess_decode(batch)
spec_info = batch.spec_info
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -481,12 +529,18 @@ class EAGLEWorker(TpModelWorker):
forward_batch
)
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)
self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
if batch.forward_mode.is_idle():
return EagleVerifyInput.create_idle_input(
self.topk,
self.speculative_num_steps,
self.speculative_num_draft_tokens,
)
(
tree_mask,
@@ -504,7 +558,7 @@ class EAGLEWorker(TpModelWorker):
batch.seq_lens_sum,
self.topk,
self.speculative_num_steps,
self.server_args.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
)
return EagleVerifyInput(
@@ -584,11 +638,16 @@ class EAGLEWorker(TpModelWorker):
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
spec_info.prepare_for_verify(batch, self.page_size)
batch.return_hidden_states = False
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.forward_mode = (
ForwardMode.TARGET_VERIFY
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
@@ -646,7 +705,9 @@ class EAGLEWorker(TpModelWorker):
self.add_logprob_values(batch, res, logits_output)
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
)
batch.spec_info = res.draft_input
return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +804,7 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -753,19 +815,37 @@ class EAGLEWorker(TpModelWorker):
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
def forward_draft_extend_after_decode(
self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
):
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
# Prepare metadata
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=self.model_config.hidden_size,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -776,7 +856,8 @@ class EAGLEWorker(TpModelWorker):
# Run
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
can_run_draft_extend_cuda_graph
and self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
)
if can_cuda_graph:
@@ -789,7 +870,10 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
@@ -799,7 +883,9 @@ class EAGLEWorker(TpModelWorker):
# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.forward_mode = (
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
)
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.accept_length = accept_length_backup