DP Enhancement (#8280)

This commit is contained in:
Cheng Wan
2025-07-24 21:36:21 -07:00
committed by GitHub
parent 28d4d47280
commit c0fb25e949
20 changed files with 665 additions and 1116 deletions

View File

@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
"""Run speculative decoding forward.
NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
self.verify(batch, spec_info)
)
if self.check_forward_draft_extend_after_decode(batch):
with self.draft_tp_context(self.draft_model_runner.tp_group):
self.forward_draft_extend_after_decode(
batch,
)
with self.draft_tp_context(self.draft_model_runner.tp_group):
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
if (
self.server_args.enable_dp_attention
or batch.spec_info.verified_id.shape[0] > 0
):
# decode is not finished
self.forward_draft_extend_after_decode(batch)
return (
logits_output,
verify_output.verified_id,
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
)
def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
local_need_forward = (
batch.spec_info.verified_id is not None
and batch.spec_info.verified_id.shape[0] > 0
)
local_need_forward = batch.spec_info.verified_id.shape[0] > 0
if not self.server_args.enable_dp_attention:
return local_need_forward
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
"""Run the target extend.
Args:
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
model_worker_batch.spec_num_draft_tokens = 1
logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
model_worker_batch
)
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
self._draft_preprocess_decode(batch)
spec_info = batch.spec_info
assert isinstance(spec_info, EagleDraftInput)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
spec_info.num_tokens_per_batch = self.topk
spec_info.num_tokens_for_logprob_per_batch = self.topk
batch.return_hidden_states = False
# Get forward batch
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.topk
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
forward_batch
)
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
spec_info = forward_batch.spec_info
assert isinstance(spec_info, EagleDraftInput)
out_cache_loc = forward_batch.out_cache_loc
topk_p, topk_index, hidden_states = (
spec_info.topk_p,
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
spec_info.hidden_states = hidden_states
# Run forward
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
else ForwardMode.IDLE
)
batch.spec_info = spec_info
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
if batch.has_grammar:
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
seq_lens_cpu: torch.Tensor,
next_token_ids: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
):
"""Run draft model extend. This API modifies the states of the batch.
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
num_tokens_per_batch=1,
num_tokens_for_logprob_per_batch=1,
)
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
model_worker_batch.spec_num_draft_tokens = 1
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
assert isinstance(batch.spec_info, EagleDraftInput)
# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
req_pool_indices_backup = batch.req_pool_indices
accept_length_backup = batch.spec_info.accept_length
return_logprob_backup = batch.return_logprob
input_is_idle = batch.forward_mode.is_idle()
if not input_is_idle:
# Prepare metadata
if batch.spec_info.verified_id is not None:
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
else:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
else self.model_config.hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
batch.spec_info.num_tokens_for_logprob_per_batch = 1
batch.spec_info.prepare_extend_after_decode(
batch,
self.speculative_num_steps,
)
batch.forward_mode = (
ForwardMode.DRAFT_EXTEND
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
self.draft_model_runner.attn_backend.init_forward_metadata(
forward_batch
)
logits_output = self.draft_model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
logits_output, _ = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
)
self.capture_for_decode(logits_output, forward_batch.spec_info)