From dc0705a504fc423cbf38376eb864c898578f7c9a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 9 Jun 2025 16:39:21 -0700 Subject: [PATCH] Simplify prepare_extend_after_decode (#6987) --- python/sglang/srt/managers/schedule_batch.py | 14 +- python/sglang/srt/managers/scheduler.py | 7 +- .../srt/model_executor/cuda_graph_runner.py | 23 +-- python/sglang/srt/server_args.py | 10 +- .../eagle_draft_cuda_graph_runner.py | 4 +- .../eagle_draft_extend_cuda_graph_runner.py | 14 +- python/sglang/srt/speculative/eagle_utils.py | 169 +++++------------- python/sglang/srt/speculative/eagle_worker.py | 71 ++++++-- test/srt/test_full_deepseek_v3.py | 4 +- 9 files changed, 140 insertions(+), 176 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index e61ac7aec..6cf85bdc2 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1636,7 +1636,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if self.spec_info: self.spec_info.merge_batch(other.spec_info) - def get_model_worker_batch(self) -> ModelWorkerBatch: + def get_model_worker_batch( + self, seq_lens_cpu_cache: Optional[torch.Tensor] = None + ) -> ModelWorkerBatch: if self.forward_mode.is_decode_or_idle(): extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: @@ -1646,16 +1648,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Create seq_lens_cpu when needed if ( - ( + global_server_args_dict["attention_backend"] == "fa3" + or ( global_server_args_dict["use_mla_backend"] and global_server_args_dict["attention_backend"] == "flashinfer" ) or global_server_args_dict["attention_backend"] == "flashmla" - or global_server_args_dict["attention_backend"] == "fa3" or global_server_args_dict["attention_backend"] == "cutlass_mla" or global_server_args_dict["enable_two_batch_overlap"] ): - seq_lens_cpu = self.seq_lens.cpu() + seq_lens_cpu = ( + seq_lens_cpu_cache + if seq_lens_cpu_cache is not None + else self.seq_lens.cpu() + ) else: seq_lens_cpu = None diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 27223a6a4..a44515ab9 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1575,10 +1575,9 @@ class Scheduler( num_accepted_tokens, can_run_cuda_graph, ) = self.draft_worker.forward_batch_speculative_generation(batch) - self.spec_num_total_accepted_tokens += ( - num_accepted_tokens + batch.batch_size() - ) - self.spec_num_total_forward_ct += batch.batch_size() + bs = batch.batch_size() + self.spec_num_total_accepted_tokens += num_accepted_tokens + bs + self.spec_num_total_forward_ct += bs self.num_generated_tokens += num_accepted_tokens if self.pp_group.is_last_rank: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 5583a0884..84af7bc06 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -56,6 +56,16 @@ def get_is_capture_mode(): return is_capture_mode +@contextmanager +def model_capture_mode(): + global is_capture_mode + is_capture_mode = True + + yield + + is_capture_mode = False + + def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): @@ -291,22 +301,13 @@ class CudaGraphRunner: # Capture try: - with self.model_capture_mode(): + with 
model_capture_mode(): self.capture() except RuntimeError as e: raise Exception( f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" ) - @contextmanager - def model_capture_mode(self): - global is_capture_mode - is_capture_mode = True - - yield - - is_capture_mode = False - def can_run(self, forward_batch: ForwardBatch): if self.enable_dp_attention or self.enable_sp_layernorm: total_global_tokens = sum(forward_batch.global_num_tokens_cpu) @@ -650,6 +651,8 @@ class CudaGraphRunner: topk=self.model_runner.server_args.speculative_eagle_topk, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=None, + seq_lens_cpu=None, ) return spec_info diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 99928f1b7..92a86a0aa 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1013,13 +1013,13 @@ class ServerArgs: type=str, choices=[ "aiter", - "flashinfer", - "triton", - "torch_native", - "fa3", - "flashmla", "cutlass_mla", + "fa3", + "flashinfer", + "flashmla", "intel_amx", + "torch_native", + "triton", ], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 49ca46a99..cfeacbb21 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -10,6 +10,7 @@ from sglang.srt.model_executor.cuda_graph_runner import ( CudaGraphRunner, get_batch_sizes_to_capture, get_global_graph_memory_pool, + model_capture_mode, set_global_graph_memory_pool, set_torch_compile_config, ) @@ -80,7 +81,8 @@ class EAGLEDraftCudaGraphRunner: # Capture try: - self.capture() + with model_capture_mode(): + self.capture() except RuntimeError as e: raise Exception( f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index d6313ca40..cc61e5001 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -11,6 +11,7 @@ from sglang.srt.model_executor.cuda_graph_runner import ( LogitsProcessorOutput, get_batch_sizes_to_capture, get_global_graph_memory_pool, + model_capture_mode, set_global_graph_memory_pool, set_torch_compile_config, ) @@ -19,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.speculative.eagle_utils import EagleDraftInput +from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk if TYPE_CHECKING: from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -37,6 +38,7 @@ class EAGLEDraftExtendCudaGraphRunner: self.tp_size = self.model_runner.tp_size self.dp_size = model_runner.server_args.dp_size self.speculative_num_steps = model_runner.server_args.speculative_num_steps + self.topk = model_runner.server_args.speculative_eagle_topk self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.padded_static_len = -1 @@ -87,7 +89,8 @@ class EAGLEDraftExtendCudaGraphRunner: # Capture try: - self.capture() + with model_capture_mode(): + self.capture() except RuntimeError as e: raise Exception( f"Capture cuda graph failed: 
{e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" @@ -170,6 +173,8 @@ class EAGLEDraftExtendCudaGraphRunner: forward_batch.positions, forward_batch, ) + probs = torch.softmax(ret.next_token_logits, dim=-1) + ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1) forward_batch.out_cache_loc = output_cache_loc_backup forward_batch.spec_info.hidden_states = hidden_states_backup @@ -198,7 +203,7 @@ class EAGLEDraftExtendCudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] - if bs != raw_bs: + if bs * self.num_tokens_per_bs != num_tokens: self.seq_lens.fill_(1) self.accept_length.fill_(1) self.out_cache_loc.zero_() @@ -238,8 +243,11 @@ class EAGLEDraftExtendCudaGraphRunner: out = self.output_buffers[bs] if bs != raw_bs: forward_batch.spec_info.accept_length = self.accept_length[:raw_bs] + out_copy = out out = LogitsProcessorOutput( next_token_logits=out.next_token_logits[:raw_bs], hidden_states=out.hidden_states[:raw_bs], ) + out.topk_p = out_copy.topk_p[:raw_bs] + out.topk_index = out_copy.topk_index[:raw_bs] return out diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 389eb7442..577e8009b 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -22,8 +22,7 @@ from sglang.srt.managers.schedule_batch import ( global_server_args_dict, ) from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator -from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2 if is_cuda(): @@ -86,78 +85,29 @@ class EagleDraftInput: self, batch: ScheduleBatch, speculative_num_steps: int, - context_length: int, - pad_input: bool = False, ): - accept_length_cpu = batch.spec_info.accept_length_cpu - batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.forward_mode = ForwardMode.DRAFT_EXTEND + batch.input_ids = self.verified_id + batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] batch.extend_num_tokens = sum(batch.extend_lens) batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend - seq_lens_cpu = batch.seq_lens.tolist() + batch.return_logprob = False - self.positions = torch.empty_like(self.verified_id, dtype=torch.long) - new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) + self.capture_hidden_mode = CaptureHiddenMode.LAST self.accept_length.add_(1) + self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) + self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) - create_extend_spec_info[(self.accept_length.numel(),)]( - self.verified_id, + create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( + batch.input_ids, batch.seq_lens, self.accept_length, - torch.cumsum(self.accept_length, axis=0, dtype=torch.int), self.positions, - new_verified_id, - next_power_of_2(speculative_num_steps + 1), + self.verified_id, + next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), ) - batch.seq_lens_sum = sum(seq_lens_cpu) - batch.input_ids = self.verified_id - self.verified_id = new_verified_id - - if not pad_input: - return - - batch_size = sum(not req.finished() for req in batch.reqs) - # Total constant input length after padding - 
static_len = speculative_num_steps + 1 - # Total size after padding - padded_input_size = batch_size * static_len - - padded_len = padded_input_size - batch.input_ids.shape[0] - if padded_len > 0: - new_input_ids = torch.nn.functional.pad( - batch.input_ids, (0, padded_len), value=0 - ) - position_padding = torch.arange(padded_len, device=self.positions.device) - new_positions = torch.cat([self.positions, position_padding]) - - # need dummy hidden states for the padded positions - hidden_states_dim = self.hidden_states.shape[-1] - new_hidden_states = torch.cat( - [ - self.hidden_states, - torch.zeros( - (padded_len, hidden_states_dim), - dtype=self.hidden_states.dtype, - device=self.hidden_states.device, - ), - ], - dim=0, - ) - - # allocate KV cache location for the padded tokens - padded_cache_loc = torch.zeros( - padded_len, - dtype=batch.out_cache_loc.dtype, - device=batch.out_cache_loc.device, - ) - new_out_cache_loc = torch.cat([batch.out_cache_loc, padded_cache_loc]) - - batch.input_ids = new_input_ids - self.hidden_states = new_hidden_states - self.positions = new_positions - batch.out_cache_loc = new_out_cache_loc - def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, @@ -173,8 +123,9 @@ class EagleDraftInput: cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - # TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync. - kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) create_flashinfer_kv_indices_triton[(bs,)]( req_to_token, @@ -238,54 +189,10 @@ class EagleVerifyInput: topk: int draft_token_num: int capture_hidden_mode: CaptureHiddenMode + seq_lens_sum: int + seq_lens_cpu: torch.Tensor grammar: BaseGrammarObject = None - @classmethod - def create( - cls, - verified_id: torch.Tensor, - score_list: List[torch.Tensor], - token_list: List[torch.Tensor], - parents_list: List[torch.Tensor], - seq_lens: torch.Tensor, - seq_lens_sum: int, - topk: int, - spec_steps: int, - num_verify_tokens: int, - ): - ( - tree_mask, - position, - retrive_index, - retrive_next_token, - retrive_next_sibling, - draft_tokens, - ) = build_tree_kernel_efficient( - verified_id, - score_list, - token_list, - parents_list, - seq_lens, - seq_lens_sum, - topk, - spec_steps, - num_verify_tokens, - ) - - return cls( - draft_token=draft_tokens, - custom_mask=tree_mask, - positions=position, - retrive_index=retrive_index, - retrive_next_token=retrive_next_token, - retrive_next_sibling=retrive_next_sibling, - retrive_cum_len=None, - spec_steps=spec_steps, - topk=topk, - draft_token_num=num_verify_tokens, - capture_hidden_mode=CaptureHiddenMode.FULL, - ) - def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): batch.input_ids = self.draft_token @@ -614,26 +521,28 @@ class EagleVerifyInput: @triton.jit -def create_extend_spec_info( +def create_extend_after_decode_spec_info( verified_id, - seq_len, - accept_len, - accept_len_cum, + seq_lens, + accept_lens, positions, new_verified_id, - accept_len_upper: tl.constexpr, + bs_upper: tl.constexpr, ): pid = tl.program_id(axis=0) - offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) - seq_length = tl.load(seq_len + pid) - accept_length = tl.load(accept_len + pid) - positions_ptr = positions + offset - data = tl.arange(0, accept_len_upper) - mask = data < accept_length - tl.store(positions_ptr + data, seq_length 
- accept_length + data, mask) + offsets = tl.arange(0, bs_upper) + seq_length = tl.load(seq_lens + pid) + accept_length = tl.load(accept_lens + pid) - offset = tl.load(accept_len_cum + pid) - 1 - verified_id_data = tl.load(verified_id + offset) + accept_len_cumsum = tl.sum( + tl.load(accept_lens + offsets, mask=offsets < pid, other=0) + ) + positions_ptr = positions + accept_len_cumsum + mask = offsets < accept_length + tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask) + + accept_len_cumsum += accept_length - 1 + verified_id_data = tl.load(verified_id + accept_len_cumsum) tl.store(new_verified_id + pid, verified_id_data) @@ -654,8 +563,8 @@ def assign_req_to_token_pool( token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len length_offset = tl.arange(0, bs_upper) - start = tl.load(start_offset + length_offset, mask=length_offset < pid) - end = tl.load(end_offset + length_offset, mask=length_offset < pid) + start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0) + end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0) out_offset = tl.sum(end - start, axis=0) out_cache_ptr = out_cache_loc + out_offset @@ -736,7 +645,7 @@ def generate_draft_decode_kv_indices( iters += 1 load_offset = tl.arange(0, bs_upper) - seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) + seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0) seq_len = tl.load(paged_kernel_lens + bid) cum_seq_len = tl.sum(seq_lens) @@ -765,7 +674,7 @@ def generate_draft_decode_kv_indices( zid = bid * topk + topk_id if zid == 0: zid = num_seqs * topk - positions = tl.load(positions + bs_offset, mask=bs_offset < zid) + positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0) base = tl.sum(positions) tl.store(kv_indptr + zid, base + zid * iters) @@ -783,7 +692,9 @@ def align_evict_mask_to_page_size( bid = tl.program_id(axis=0) seq_len = tl.load(seq_lens + bid) io_mask = t_range < num_draft_tokens - mask_row = tl.load(evict_mask + bid * num_draft_tokens + t_range, mask=io_mask) + mask_row = tl.load( + evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0 + ) num_trues = tl.sum(mask_row) num_false = num_draft_tokens - num_trues diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index af54a8619..bc0b50f31 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, ) from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( EAGLEDraftCudaGraphRunner, ) @@ -69,7 +70,6 @@ class EAGLEWorker(TpModelWorker): self.server_args = server_args self.topk = server_args.speculative_eagle_topk self.speculative_num_steps = server_args.speculative_num_steps - self.padded_static_len = self.speculative_num_steps + 1 self.enable_nan_detection = server_args.enable_nan_detection self.gpu_id = gpu_id self.device = server_args.device @@ -78,6 +78,7 @@ class EAGLEWorker(TpModelWorker): self.speculative_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) + self.padded_static_len = -1 # Override context length with target model's context length server_args.context_length = target_worker.model_runner.model_config.context_len @@ -184,7 +185,6 
@@ class EAGLEWorker(TpModelWorker): self.draft_model_runner, skip_prefill=False, ) - self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = True elif self.server_args.attention_backend == "triton": from sglang.srt.layers.attention.triton_backend import ( @@ -201,7 +201,6 @@ class EAGLEWorker(TpModelWorker): self.draft_model_runner, skip_prefill=False, ) - self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "fa3": from sglang.srt.layers.attention.flashattention_backend import ( @@ -218,7 +217,6 @@ class EAGLEWorker(TpModelWorker): self.draft_model_runner, skip_prefill=False, ) - self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False elif self.server_args.attention_backend == "flashmla": from sglang.srt.layers.attention.flashmla_backend import ( @@ -231,7 +229,6 @@ class EAGLEWorker(TpModelWorker): self.speculative_num_steps, ) self.draft_extend_attn_backend = None - self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False else: raise ValueError( @@ -319,10 +316,12 @@ class EAGLEWorker(TpModelWorker): return logits_output, next_token_ids, model_worker_batch.bid, 0, False else: - logits_output, next_token_ids, bid = self.forward_target_extend(batch) + logits_output, next_token_ids, bid, seq_lens_cpu = ( + self.forward_target_extend(batch) + ) with self.draft_tp_context(self.draft_model_runner.tp_group): self.forward_draft_extend( - batch, logits_output.hidden_states, next_token_ids + batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu ) return logits_output, next_token_ids, bid, 0, False @@ -346,7 +345,12 @@ class EAGLEWorker(TpModelWorker): logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation( model_worker_batch ) - return logits_output, next_token_ids, model_worker_batch.bid + return ( + logits_output, + next_token_ids, + model_worker_batch.bid, + model_worker_batch.seq_lens_cpu, + ) def draft(self, batch: ScheduleBatch): # Parse args @@ -452,7 +456,14 @@ class EAGLEWorker(TpModelWorker): self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup) - ret = EagleVerifyInput.create( + ( + tree_mask, + position, + retrive_index, + retrive_next_token, + retrive_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient( spec_info.verified_id, score_list, token_list, @@ -463,7 +474,22 @@ class EAGLEWorker(TpModelWorker): self.speculative_num_steps, self.server_args.speculative_num_draft_tokens, ) - return ret + + return EagleVerifyInput( + draft_token=draft_tokens, + custom_mask=tree_mask, + positions=position, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + retrive_cum_len=None, + spec_steps=self.speculative_num_steps, + topk=self.topk, + draft_token_num=self.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=forward_batch.seq_lens_sum, + seq_lens_cpu=forward_batch.seq_lens_cpu, + ) def draft_forward(self, forward_batch: ForwardBatch): # Parse args @@ -523,7 +549,9 @@ class EAGLEWorker(TpModelWorker): spec_info.prepare_for_verify(batch, self.page_size) batch.forward_mode = ForwardMode.TARGET_VERIFY batch.spec_info = spec_info - model_worker_batch = batch.get_model_worker_batch() + model_worker_batch = batch.get_model_worker_batch( + seq_lens_cpu_cache=spec_info.seq_lens_cpu + ) if batch.has_grammar: 
retrieve_next_token_cpu = spec_info.retrive_next_token.cpu() @@ -650,6 +678,7 @@ class EAGLEWorker(TpModelWorker): batch: ScheduleBatch, hidden_states: torch.Tensor, next_token_ids: List[int], + seq_lens_cpu: torch.Tensor, ): """Run draft model extend. This API modifies the states of the batch. @@ -664,7 +693,9 @@ class EAGLEWorker(TpModelWorker): ) batch.spec_info.prepare_for_extend(batch) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - model_worker_batch = batch.get_model_worker_batch() + model_worker_batch = batch.get_model_worker_batch( + seq_lens_cpu_cache=seq_lens_cpu + ) forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) @@ -683,19 +714,18 @@ class EAGLEWorker(TpModelWorker): return_logprob_backup = batch.return_logprob # Prepare metadata - batch.forward_mode = ForwardMode.DRAFT_EXTEND batch.spec_info.prepare_extend_after_decode( batch, self.speculative_num_steps, - self.server_args.context_length, - pad_input=self.cuda_graph_runner_for_draft_extend is not None, ) - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - batch.return_logprob = False model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new( model_worker_batch, self.draft_model_runner ) + if forward_batch.seq_lens_cpu is not None: + forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item() + else: + forward_batch.seq_lens_sum = batch.seq_lens.sum().item() # Run can_cuda_graph = ( @@ -706,14 +736,19 @@ class EAGLEWorker(TpModelWorker): logits_output = self.cuda_graph_runner_for_draft_extend.replay( forward_batch ) + forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = ( + logits_output.topk_p, + logits_output.topk_index, + ) + forward_batch.spec_info.hidden_states = logits_output.hidden_states else: self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch) logits_output = self.draft_model_runner.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) + self.capture_for_decode(logits_output, forward_batch.spec_info) self._detect_nan_if_needed(logits_output) - self.capture_for_decode(logits_output, forward_batch.spec_info) # Restore backup. # This is because `seq_lens` can be modified in `prepare_extend_after_decode` diff --git a/test/srt/test_full_deepseek_v3.py b/test/srt/test_full_deepseek_v3.py index 6a7bb4729..f6a58536a 100644 --- a/test/srt/test_full_deepseek_v3.py +++ b/test/srt/test_full_deepseek_v3.py @@ -87,7 +87,7 @@ class TestDeepseekV3MTP(CustomTestCase): "--speculative-num-steps", "3", "--speculative-eagle-topk", - "2", + "1", "--speculative-num-draft-tokens", "4", ] @@ -155,7 +155,7 @@ class TestDeepseekV3MTP(CustomTestCase): if is_in_amd_ci(): self.assertGreater(speed, 15) else: - self.assertGreater(speed, 105) + self.assertGreater(speed, 130) if __name__ == "__main__":
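
Note: below is a minimal pure-Python sketch (an illustration for review, not part of the patch) of what the rewritten create_extend_after_decode_spec_info Triton kernel computes per request. Each kernel program pid handles one request: it recomputes the prefix sum of accept_lens with a masked load plus tl.sum, writes the positions of the accepted tokens, and takes the last accepted token as the new verified id for the next draft-extend pass.

    def extend_after_decode_reference(verified_id, seq_lens, accept_lens):
        """Host-side reference for create_extend_after_decode_spec_info."""
        total = sum(accept_lens)
        positions = [0] * total
        new_verified_id = [0] * len(seq_lens)
        for pid in range(len(seq_lens)):
            # Equivalent of tl.sum(tl.load(accept_lens + offsets,
            #                              mask=offsets < pid, other=0))
            offset = sum(accept_lens[:pid])
            seq_len, accept_len = seq_lens[pid], accept_lens[pid]
            # Positions of this request's accepted tokens within its sequence
            for i in range(accept_len):
                positions[offset + i] = seq_len - accept_len + i
            # The last accepted token seeds the next draft step
            new_verified_id[pid] = verified_id[offset + accept_len - 1]
        return positions, new_verified_id

    # Example: two requests with sequence lengths 10 and 7,
    # accepting 3 and 2 tokens respectively.
    pos, nvid = extend_after_decode_reference(
        verified_id=[11, 12, 13, 21, 22], seq_lens=[10, 7], accept_lens=[3, 2]
    )
    assert pos == [7, 8, 9, 5, 6]
    assert nvid == [13, 22]

Computing the prefix sum inside the kernel is what lets the patch drop the host-launched torch.cumsum and the extra accept_len_cum argument that the old create_extend_spec_info kernel required.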